From 486c2a22a5117f72823b93d6a820761f9683b9b3 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Wed, 19 Jan 2022 09:16:41 +0530 Subject: [PATCH 01/63] Added functionality support for dzgemm AMD-Internal: [SWLCSG-1012] Change-Id: I2eac3131d2dcd534f84491289cbd3fe7fb7de3da --- frame/1m/packm/bli_packm_blk_var1.c | 5 +++-- frame/3/bli_l3_check.c | 7 +++++-- frame/3/gemm/bli_gemm_packab.c | 12 +++++++++++- frame/compat/bla_gemm.c | 8 +++----- frame/compat/bla_gemm.h | 4 ++-- frame/include/bli_type_defs.h | 5 +++-- test/test_gemm.c | 8 +++----- 7 files changed, 30 insertions(+), 19 deletions(-) diff --git a/frame/1m/packm/bli_packm_blk_var1.c b/frame/1m/packm/bli_packm_blk_var1.c index 87f8df4f7d..c720317b96 100644 --- a/frame/1m/packm/bli_packm_blk_var1.c +++ b/frame/1m/packm/bli_packm_blk_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -159,7 +159,8 @@ void bli_packm_blk_var1 // Treatment of kappa (ie: packing during scaling) depends on // whether we are executing an induced method. - if ( bli_is_nat_packed( schema ) ) + // For dzgemm, scale alpha during packing. + if ( bli_is_nat_packed( schema ) && cntl && bli_cntl_family(cntl) != BLIS_GEMM_MD) { // This branch is for native execution, where we assume that // the micro-kernel will always apply the alpha scalar of the diff --git a/frame/3/bli_l3_check.c b/frame/3/bli_l3_check.c index 945b267fda..43ba867283 100644 --- a/frame/3/bli_l3_check.c +++ b/frame/3/bli_l3_check.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -323,8 +324,10 @@ void bli_gemm_basic_check // When mixing datatypes, make sure that alpha does not have a non-zero // imaginary component. - if ( bli_obj_dt( c ) != bli_obj_dt( a ) || - bli_obj_dt( c ) != bli_obj_dt( b ) || + // To support dzgemm, we continue execution when datatypes of C and A + // do not match instead of aborting with an error message. + // Non-zero imaginary component of alpha is handled while packing B. + if ( bli_obj_dt( c ) != bli_obj_dt( b ) || bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) if ( !bli_obj_imag_is_zero( alpha ) ) { diff --git a/frame/3/gemm/bli_gemm_packab.c b/frame/3/gemm/bli_gemm_packab.c index 3dfed88478..6828725546 100644 --- a/frame/3/gemm/bli_gemm_packab.c +++ b/frame/3/gemm/bli_gemm_packab.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,9 +91,14 @@ void bli_gemm_packb ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_5); - + obj_t b_pack; + // BY setting family id to BLIS_GEMM_MD, we indicate packing kernels + // to scale alpha while packing. + if(bli_obj_dt(c) != bli_obj_dt(a)) + bli_cntl_set_family(BLIS_GEMM_MD, cntl); + // Pack matrix B according to the control tree node. bli_l3_packm ( @@ -103,6 +109,10 @@ void bli_gemm_packb cntl, thread ); + // Once packing of B matrix is done, fall back to GEMM execution. + if(bli_obj_dt(c) != bli_obj_dt(a)) + bli_cntl_set_family(BLIS_GEMM, cntl); + // Proceed with execution using packed matrix B. bli_gemm_int diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 50aa931a82..3cc7845739 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -804,9 +804,7 @@ INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) INSERT_GENTFUNC_BLAS( gemm,gemm ) #endif -// Observed a regression in dgemm with this function addition. -// Disabling temporarily. -#if 0 +#if 1 void dzgemm_ ( const f77_char* transa, @@ -883,7 +881,7 @@ void dzgemm_ bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); diff --git a/frame/compat/bla_gemm.h b/frame/compat/bla_gemm.h index 25aef8d11f..c9ea83149a 100644 --- a/frame/compat/bla_gemm.h +++ b/frame/compat/bla_gemm.h @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,8 +55,7 @@ BLIS_EXPORT_BLAS void PASTEF77(ch,blasname) \ ); #ifdef BLIS_ENABLE_BLAS -// Disabling temporarily -#if 0 +#if 1 BLIS_EXPORT_BLAS void dzgemm_ ( const f77_char* transa, \ diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 770f5c5378..1a3dea1d3d 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -931,10 +931,11 @@ typedef enum BLIS_TRMM, BLIS_TRSM, BLIS_GEMMT, + BLIS_GEMM_MD, BLIS_NOID } opid_t; -#define BLIS_NUM_LEVEL3_OPS 11 +#define BLIS_NUM_LEVEL3_OPS 12 // -- Blocksize ID type -- diff --git a/test/test_gemm.c b/test/test_gemm.c index 772d73c7b1..25fc5e3d8d 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -382,8 +382,7 @@ int main( int argc, char** argv ) cp, ldc ); #else -//Disabled dzgemm function temporarily. -#if 0 +#if 1 if( bli_is_double( dt_a ) ) { dzgemm_( @@ -401,7 +400,6 @@ int main( int argc, char** argv ) } else { -#else zgemm_( &f77_transa, &f77_transb, &mm, @@ -412,7 +410,7 @@ int main( int argc, char** argv ) bp, (f77_int*)&ldb, betap, cp, (f77_int*)&ldc ); -// } + } #endif #endif } From ceb5771f3b92151f70489d5f34e49d780dbd0447 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 24 Jan 2022 20:30:25 +0530 Subject: [PATCH 02/63] Updated windows build system to define BLIS_CONFIG_EPYC flag. All AMD specific optimization in BLIS are enclosed in BLIS_CONFIG_EPYC pre-preprocessor, this was not defined in CMake which are resulting in overall lower performance. Updated version number to 3.1.1 Change-Id: I9848b695a599df07da44e77e71a64414b28c75b9 --- CMakeLists.txt | 6 +++++- kernels/zen/2/CMakeLists.txt | 3 ++- so_version | 2 +- version | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8d892463a7..5018724656 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## cmake_minimum_required(VERSION 3.0.0) @@ -34,17 +34,20 @@ endif () if(${AOCL_BLIS_FAMILY} STREQUAL "zen") add_definitions(-DBLIS_FAMILY_ZEN) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") add_definitions(-DBLIS_FAMILY_ZEN2) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_FAMILY_ZEN3) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) @@ -53,6 +56,7 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(AOCL_BLIS_ZEN FALSE) add_definitions(-DBLIS_FAMILY_AMDZEN) + add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index 480837c023..dfa7c0b750 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -1,8 +1,9 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c ) diff --git a/so_version b/so_version index a831c0e579..b1f189286c 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.0 +1.1 diff --git a/version b/version index 0c6173b5f1..1795fa298a 100644 --- a/version +++ b/version @@ -1,2 +1,2 @@ -3.1.0 +3.1.1 From 5a098ab2a6514db0e4211f0126536f3a06bbf835 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Tue, 25 Jan 2022 13:53:03 +0530 Subject: [PATCH 03/63] Updated windows build system. We were using add_compile_options(-Xclang -fopenmp) statement to set omp library compiler flags on MSVC using cmake. Observing there is an performance regression because of the compiler version which is using in MSVC(clang 10), so removing it from the windows build system and configuring the compiler version(clang 13) and compiler options manually on MSVC gui to gain a performance on matlab bench. Change-Id: I37d778abdceb7c1fae9b1caaeea8adb114677dd2 --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5018724656..d6885e3a38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -264,7 +264,6 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) - add_compile_options(-Xclang -fopenmp) else() message (FATAL_ERROR "Openmp Not Found") endif() From 3eb97177922f27d8bfa9376a2946bdee897400dc Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Thu, 27 Jan 2022 13:01:22 +0530 Subject: [PATCH 04/63] Fixed a bug in deriving dimensions from objects in bli_gemm_front.c file Change-Id: I4cc38d8f1dee277f99b5532e4645e0e1bc5b31cb --- frame/3/gemm/bli_gemm_front.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index a065156bbf..c782559167 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -176,7 +176,7 @@ void bli_gemm_front dim_t m_dim_local = bli_obj_length( &c_local ); dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width( &a_local ); + dim_t k_dim_local = bli_obj_width_after_trans( &a_local ); #ifdef BLIS_CONFIG_EPYC // Regression observed in sgemm native path in cases where m >= 4 * n // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit From d01a09dfcabfc5eff4d46941a5d40fa28d7a2d7b Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Tue, 23 Nov 2021 08:33:27 -0600 Subject: [PATCH 05/63] Optimized dsymv implementation -Implemented hemv framework calls for lower and upper kernel variants. -hemv computation is implemented in two parts. One part operate on triangular part of matrix and the remaining part is computed by dotxfaxpyf kernel. -First part performs dotxf and axpyf operation on triangular part of matrix in chunk of 8x8. Two separate helper function for doing so are implemented for lower and upper kernels respectively. -Second part is ddotxaxpyf fused kernel, which performs dotxf and axpyf operation alltogether on non-triangular part of matrix in chunk of 4x8. -Implementation efficiently uses cache memory while computing for optimal performance. Change-Id: Id603031b4578e87a92c6b77f710c647acc195c8e --- frame/2/hemv/bli_hemv_unf_var1.c | 197 +++++++ frame/2/hemv/bli_hemv_unf_var3.c | 195 +++++++ kernels/zen/1f/bli_dotxaxpyf_int_8.c | 735 +++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 11 + 4 files changed, 1138 insertions(+) create mode 100644 kernels/zen/1f/bli_dotxaxpyf_int_8.c diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index d36dc00988..ccb39b3485 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -215,5 +216,201 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC + +void post_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + +void bli_dhemv_unf_var1 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A10; + double* A11; + double* a10t; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* chi11; + double* y0; + double* y1; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_behind; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + /* Query the context for the kernel function pointer and fusing + * factor. */ + /* Assign kernel function pointer and fusing factor. */ + arch_t id = bli_arch_query_id(); + bool bamdzen = ((id == BLIS_ARCH_ZEN4) ||(id == BLIS_ARCH_ZEN3) + || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); + if (bamdzen) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_behind = i; + A10 = a + (i )*rs_at + (0 )*cs_at; + A11 = a + (i )*rs_at + (i )*cs_at; + x0 = x + (0 )*incx; + x1 = x + (i )*incx; + y0 = y + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_behind, + f, + alpha, + A10, cs_at, rs_at, + x0, incx, + x1, incx, + one, + y1, incy, + y0, incy, + cntx + ); + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, *chi11, + conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,axpys)( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + } +} +GENTFUNC(float, s, hemv_unf_var1) +GENTFUNC(scomplex, c, hemv_unf_var1) +GENTFUNC(dcomplex, z, hemv_unf_var1) +#else INSERT_GENTFUNC_BASIC0( hemv_unf_var1 ) +#endif diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index d8db9bc78a..6ed18efea4 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -5,6 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -215,5 +216,199 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC +void bli_dhemv_unf_var3 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A11; + double* A21; + double* a10t; + double* alpha11; + double* a21; + double* x1; + double* x2; + double* chi11; + double* y1; + double* y2; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_ahead; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + arch_t id = bli_arch_query_id(); + bool bamdzen = ((id == BLIS_ARCH_ZEN4) || (id == BLIS_ARCH_ZEN3) + || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); + if (bamdzen) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_ahead = m - i - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + y1 = y + (i )*incy; + y2 = y + (i+f)*incy; + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, + *chi11, conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + { + for ( j = 0; j < f_behind; ++j ) + { + PASTEMAC(d,axpys) + ( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + } + } + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_ahead, + f, + alpha, + A21, rs_at, cs_at, + x2, incx, + x1, incx, + one, + y1, incy, + y2, incy, + cntx + ); + } +} + +GENTFUNC(float, s, hemv_unf_var3) +GENTFUNC(scomplex, c, hemv_unf_var3) +GENTFUNC(dcomplex, z, hemv_unf_var3) +#else INSERT_GENTFUNC_BASIC0( hemv_unf_var3 ) +#endif diff --git a/kernels/zen/1f/bli_dotxaxpyf_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_int_8.c new file mode 100644 index 0000000000..b24aab7571 --- /dev/null +++ b/kernels/zen/1f/bli_dotxaxpyf_int_8.c @@ -0,0 +1,735 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "immintrin.h" + +typedef union{ + __m256d v; + double d[4] __attribute__((aligned(64))); +}vec; + +/** + * bli_pre_hemv_lower_8x8 is a helper function which computes + * "y = y + alpha * a * x" + * dotxf and axpyf of triangular matrix with vector + * for lower triangular matrix cases. + * Computes 8 elements of Y vector by dot product + * of 8 elements of x vector with 8x8 tile of A matrix + * and axpy computation of each x vector elements with + * each column of 8x8 A matrix tile. + +*/ +void bli_pre_hemv_8x8(double *a, double *x, double *y, double *alpha, + dim_t cs_a, dim_t rs_a) +{ + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9; + __m256d ymm10, ymm11, ymm12; + double alpha_chi[8] = {0}; + /*Broadcast alpha*/ + ymm9 = _mm256_broadcast_sd(alpha); + + /** + * Scaling vector x with alpha + * to gather alpha_chi elements + * arranged in one buffer. + */ + ymm10 = _mm256_loadu_pd(x); + ymm11 = _mm256_loadu_pd(x + 4); + ymm10 = _mm256_mul_pd(ymm9, ymm10); + ymm11 = _mm256_mul_pd(ymm9, ymm11); + _mm256_storeu_pd(alpha_chi, ymm10); + _mm256_storeu_pd(alpha_chi + 4, ymm11); + + /*Load y vector*/ + ymm10 = _mm256_loadu_pd(y); + ymm11 = _mm256_loadu_pd(y + 4); + + //Col 0 computation + /*Broadcasts chi and multiplies with alpha to get alpha chi*/ + ymm12 = _mm256_broadcast_sd(alpha_chi); + /*Load first column of A matrix*/ + ymm0 = _mm256_loadu_pd(a); + ymm1 = _mm256_loadu_pd(a + 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 1 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 1); + /** + * pack the data in following manner into ymm register + * Since it is computing 2nd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 + * --- --- + x x + --- x + --- x + */ + ymm3 = _mm256_broadcast_sd(a + 1); + ymm0 = _mm256_loadu_pd(a + cs_a * 1); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 1); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 2 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 2); + /** + * pack the data in following manner into ymm register + * Since it is computing 3rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 + * --- --- --- + x x --- + --- --- x + --- --- x + */ + ymm3 = _mm256_broadcast_sd(a + 2); + ymm4 = _mm256_broadcast_sd(a + 2 + cs_a); + ymm0 = _mm256_loadu_pd(a + cs_a * 2); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x2); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 2); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 3 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 3); + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 + * --- --- --- --- + x x x --- + --- --- --- x + */ + ymm3 = _mm256_broadcast_sd(a + 3); + ymm4 = _mm256_broadcast_sd(a + 3 + cs_a); + ymm5 = _mm256_broadcast_sd(a + 3 + cs_a * 2); + ymm0 = _mm256_loadu_pd(a + cs_a * 3); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x1); + ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x2); + ymm0 = _mm256_blend_pd(ymm0, ymm5, 0x4); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 3); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Transpose 4x4 tile of matrix A, + * for remainder column computation. + */ + ymm0 = _mm256_loadu_pd(a+4 + cs_a * 0); + ymm1 = _mm256_loadu_pd(a+4 + cs_a * 1); + ymm2 = _mm256_loadu_pd(a+4 + cs_a * 2); + ymm3 = _mm256_loadu_pd(a+4 + cs_a * 3); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //Transposed col 1 + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //Transposed col 3 + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //Transposed col 2 + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //Transposed col 4 + + //Col 4 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 4); + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 col-4 + * --- --- --- --- --- + x x x x --- + --- --- --- --- --- + --- --- --- --- --- + */ + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm6, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 5 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 5th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 + * --- --- + x x + --- x + --- x + + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 5); + ymm3 = _mm256_broadcast_sd(a + 5 + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 5); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm10 = _mm256_fmadd_pd(ymm12, ymm7, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 6 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 6th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 col-6 + * --- --- --- + x x --- + --- --- x + --- --- x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 6); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 6); + ymm3 = _mm256_broadcast_sd(a + 6 + cs_a * 4); + ymm4 = _mm256_broadcast_sd(a + 6 + cs_a * 5); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm10 = _mm256_fmadd_pd(ymm12, ymm8, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 7 computation + /** + * Packs the data in similar manner as shown + * for col 0-4 computation, along with + * packing all 7th elements from col 0 - 4 + * in other ymm register. + * col-4 col-5 col-6 col-7 + * --- --- --- --- + x x x --- + --- --- --- x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 7); + ymm1 = _mm256_loadu_pd(a + 4 + cs_a * 7); + ymm3 = _mm256_broadcast_sd(a + 7 + cs_a * 4); + ymm4 = _mm256_broadcast_sd(a + 7 + cs_a * 5); + ymm5 = _mm256_broadcast_sd(a + 7 + cs_a * 6); + ymm1 = _mm256_blend_pd(ymm1, ymm3, 0x1); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm9, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Computed result of vector y is available in ymm10, ymm11. + * Storing the result back from ymm register into y vector for + * further computaion. + */ + _mm256_storeu_pd(y, ymm10); + _mm256_storeu_pd(y + 4, ymm11); +} + + +/** + * bli_post_hemv_lower_8x8 is a helper function which computes + * "y = y + alpha * a * x" + * dotxf and axpyf of triangular matrix with vector + * for upper triangular matrix cases. + * Computes 8 elements of Y vector by dot product + * of 8 elements of x vector with 8x8 tile of A matrix + * and axpy computation of each x vector elements with + * each column of 8x8 A matrix tile. +*/ +void bli_post_hemv_8x8(double *a, double *x, double *y, double *alpha, + dim_t cs_a, dim_t rs_a) +{ + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9; + __m256d ymm10, ymm11, ymm12; + double alpha_chi[8] = {0}; + + ymm9 = _mm256_broadcast_sd(alpha); + + ymm10 = _mm256_loadu_pd(x); + ymm11 = _mm256_loadu_pd(x + 4); + ymm10 = _mm256_mul_pd(ymm9, ymm10); + ymm11 = _mm256_mul_pd(ymm9, ymm11); + _mm256_storeu_pd(alpha_chi, ymm10); + _mm256_storeu_pd(alpha_chi + 4, ymm11); + + ymm10 = _mm256_loadu_pd(y); + ymm11 = _mm256_loadu_pd(y + 4); + + ymm0 = _mm256_loadu_pd(a + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + cs_a * 5); + ymm2 = _mm256_loadu_pd(a + cs_a * 6); + ymm3 = _mm256_loadu_pd(a + cs_a * 7); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + //Col 0 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-0 col-1 col-2 col-3 + * x x x x + --- + --- + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi); + ymm0 = _mm256_loadu_pd(a); + ymm1 = _mm256_broadcast_sd(a + cs_a * 1); + ymm2 = _mm256_broadcast_sd(a + cs_a * 2); + ymm3 = _mm256_broadcast_sd(a + cs_a * 3); + ymm0 = _mm256_blend_pd(ymm0, ymm1, 0x2); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x4); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm6, ymm11); + + //Col 1 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-1 col-2 col-3 + * x x x + x + --- + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 1); + ymm0 = _mm256_loadu_pd(a + cs_a * 1); + ymm2 = _mm256_broadcast_sd(a + cs_a * 2 + 1); + ymm3 = _mm256_broadcast_sd(a + cs_a * 3 + 1); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x4); + ymm0 = _mm256_blend_pd(ymm0, ymm3, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm7, ymm11); + + //Col 2 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-2 col-3 + * x x + x + x + --- + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 2); + ymm0 = _mm256_loadu_pd(a + cs_a * 2); + ymm2 = _mm256_broadcast_sd(a + cs_a * 3 + 2); + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm8, ymm11); + + //Col 3 computation + /** + * pack the data in following manner into ymm register + * Since it is computing 4rd column, packing to be done + * as shown below for ymm0: + * col-3 + * x + x + x + x + */ + ymm12 = _mm256_broadcast_sd(alpha_chi + 3); + ymm0 = _mm256_loadu_pd(a + cs_a * 3); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm9, ymm11); + + //Col 4 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 4); + ymm0 = _mm256_loadu_pd(a + cs_a * 4); + ymm1 = _mm256_loadu_pd(a + cs_a * 4 + 4); + ymm4 = _mm256_broadcast_sd(a + cs_a * 5 + 4); + ymm5 = _mm256_broadcast_sd(a + cs_a * 6 + 4); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 4); + ymm1 = _mm256_blend_pd(ymm1, ymm4, 0x2); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 5 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 5); + ymm0 = _mm256_loadu_pd(a + cs_a * 5); + ymm1 = _mm256_loadu_pd(a + cs_a * 5 + 4); + ymm5 = _mm256_broadcast_sd(a + cs_a * 6 + 5); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 5); + ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x4); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 6 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 6); + ymm0 = _mm256_loadu_pd(a + cs_a * 6); + ymm1 = _mm256_loadu_pd(a + cs_a * 6 + 4); + ymm6 = _mm256_broadcast_sd(a + cs_a * 7 + 6); + ymm1 = _mm256_blend_pd(ymm1, ymm6, 0x8); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + //Col 7 computation + ymm12 = _mm256_broadcast_sd(alpha_chi + 7); + ymm0 = _mm256_loadu_pd(a + cs_a * 7); + ymm1 = _mm256_loadu_pd(a + cs_a * 7 + 4); + ymm10 = _mm256_fmadd_pd(ymm12, ymm0, ymm10); + ymm11 = _mm256_fmadd_pd(ymm12, ymm1, ymm11); + + /** + * Computed result of vector y is available in ymm10, ymm11. + * Storing the result back from ymm register into y vector for + * further computaion. + */ + _mm256_storeu_pd(y, ymm10); + _mm256_storeu_pd(y + 4, ymm11); +} + + +/** + * ddotxaxpyf kernel performs dot and apxy function all togather + * on a tile of 4x8 size. + * x_trsv holds 4 elements of vector x, a_tile[0-7] holds + * 4x8 tile of A matrix. + * Following equations are solved in a way represented + * y1 = y1 + alpha * A21' * x2; (dotxf) + y2 = y2 + alpha * A21 * x1; (axpyf) + + * B1 B2 B3 B4 B5 B6 B7 B8 + * (broadcast elements of [x*alpha] vector) + * tile 0 1 2 3 4 5 6 7 + * x_trsv[0] A00 A01 A02 A03 => rho0 | A04 A05 A06 A07 => rho4 + * x_trsv[1] A10 A11 A12 A13 => rho1 | A14 A15 A16 A17 => rho5 + * x_trsv[2] A20 A21 A22 A23 => rho2 | A24 A25 A26 A27 => rho6 + * x_trsv[3] A30 A31 A32 A33 => rho3 | A34 A35 A36 A37 => rho7 + || || || || || || || || + \/ \/ \/ \/ \/ \/ \/ \/ + += += += += += += += += + z_vec z_vec z_vec z_vec z_vec z_vec z_vec z_vec + * + * + */ +void bli_ddotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + double* restrict alpha, + double* restrict a, inc_t inca, inc_t lda, + double* restrict w, inc_t incw, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + double* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + /* A is m x n. */ + /* y = beta * y + alpha * A^T w; */ + /* z = z + alpha * A x; */ + if ((inca == 1) && (incw == 1) && (incx == 1) + && (incy == 1) && (incz == 1) && (b_n == 8)) + { + __m256d r0, r1; + r0 = _mm256_setzero_pd(); + r1 = _mm256_setzero_pd(); + + /* If beta is zero, clear y. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early*/ + if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) return; + + dim_t row = 0; + dim_t iter = m/4; + dim_t rem = m%4; + if(iter) + { + vec x_trsv, x_hemvB1, x_hemvB2, x_hemvB3, x_hemvB4; + vec x_hemvB5, x_hemvB6, x_hemvB7, x_hemvB8; + + vec a_tile0, a_tile1, a_tile2, a_tile3; + vec a_tile4, a_tile5, a_tile6, a_tile7; + + vec rho0, rho1, rho2, rho3; + vec rho4, rho5, rho6, rho7; + + __m256d z_vec; + + /** + * Load [x vector * alpha], broadcast each element into + * different ymm registers. To perform axpyf operation + * with 4x8 tile of A matrix. + */ + + x_hemvB1.v = _mm256_set1_pd(x[0*incx] * (*alpha)); + x_hemvB2.v = _mm256_set1_pd(x[1*incx] * (*alpha)); + x_hemvB3.v = _mm256_set1_pd(x[2*incx] * (*alpha)); + x_hemvB4.v = _mm256_set1_pd(x[3*incx] * (*alpha)); + + x_hemvB5.v = _mm256_set1_pd(x[4*incx] * (*alpha)); + x_hemvB6.v = _mm256_set1_pd(x[5*incx] * (*alpha)); + x_hemvB7.v = _mm256_set1_pd(x[6*incx] * (*alpha)); + x_hemvB8.v = _mm256_set1_pd(x[7*incx] * (*alpha)); + + /** + * clear rho register which holds result of + * fmadds for dotxf operation. + * Once micro tile is computed, horizontal addition + * of all rho's will provide us with the result of + * dotxf opereation. + */ + rho0.v = _mm256_setzero_pd(); + rho1.v = _mm256_setzero_pd(); + rho2.v = _mm256_setzero_pd(); + rho3.v = _mm256_setzero_pd(); + rho4.v = _mm256_setzero_pd(); + rho5.v = _mm256_setzero_pd(); + rho6.v = _mm256_setzero_pd(); + rho7.v = _mm256_setzero_pd(); + + for(; (row + 3) < m; row+= 4) + { + a_tile0.v = _mm256_loadu_pd((double *) + &a[row + 0 * lda] ); + a_tile1.v = _mm256_loadu_pd((double *) + &a[row + 1 * lda] ); + a_tile2.v = _mm256_loadu_pd((double *) + &a[row + 2 * lda] ); + a_tile3.v = _mm256_loadu_pd((double *) + &a[row + 3 * lda] ); + a_tile4.v = _mm256_loadu_pd((double *) + &a[row + 4 * lda] ); + a_tile5.v = _mm256_loadu_pd((double *) + &a[row + 5 * lda] ); + a_tile6.v = _mm256_loadu_pd((double *) + &a[row + 6 * lda] ); + a_tile7.v = _mm256_loadu_pd((double *) + &a[row + 7 * lda] ); + + x_trsv.v = _mm256_loadu_pd((double *) &w[row]); + z_vec = _mm256_loadu_pd((double *) &z[row] ); + + //dot product operation + rho0.v = _mm256_fmadd_pd(a_tile0.v, + x_trsv.v, rho0.v); + rho4.v = _mm256_fmadd_pd(a_tile4.v, + x_trsv.v, rho4.v); + + rho1.v = _mm256_fmadd_pd(a_tile1.v, + x_trsv.v, rho1.v); + rho5.v = _mm256_fmadd_pd(a_tile5.v, + x_trsv.v, rho5.v); + + rho2.v = _mm256_fmadd_pd(a_tile2.v, + x_trsv.v, rho2.v); + rho6.v = _mm256_fmadd_pd(a_tile6.v, + x_trsv.v, rho6.v); + + rho3.v = _mm256_fmadd_pd(a_tile3.v, + x_trsv.v, rho3.v); + rho7.v = _mm256_fmadd_pd(a_tile7.v, + x_trsv.v, rho7.v); + + //axpy operation + z_vec = _mm256_fmadd_pd(a_tile0.v, + x_hemvB1.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile1.v, + x_hemvB2.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile2.v, + x_hemvB3.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile3.v, + x_hemvB4.v, z_vec); + + z_vec = _mm256_fmadd_pd(a_tile4.v, + x_hemvB5.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile5.v, + x_hemvB6.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile6.v, + x_hemvB7.v, z_vec); + z_vec = _mm256_fmadd_pd(a_tile7.v, + x_hemvB8.v, z_vec); + + _mm256_storeu_pd((double *)&z[row], z_vec); + } + /*Horizontal addition of rho's elements to compute + * the final dotxf result. + */ + rho0.v = _mm256_hadd_pd( rho0.v, rho1.v ); + rho2.v = _mm256_hadd_pd( rho2.v, rho3.v ); + rho4.v = _mm256_hadd_pd( rho4.v, rho5.v ); + rho6.v = _mm256_hadd_pd( rho6.v, rho7.v ); + + { + __m128d xmm0, xmm1; + + xmm0 = _mm256_extractf128_pd(rho0.v, 0); + xmm1 = _mm256_extractf128_pd(rho0.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r0 = _mm256_insertf128_pd(r0, xmm0, 0); + + xmm0 = _mm256_extractf128_pd(rho2.v, 0); + xmm1 = _mm256_extractf128_pd(rho2.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r0 = _mm256_insertf128_pd(r0, xmm0, 1); + + + xmm0 = _mm256_extractf128_pd(rho4.v, 0); + xmm1 = _mm256_extractf128_pd(rho4.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r1 = _mm256_insertf128_pd(r1, xmm0, 0); + + xmm0 = _mm256_extractf128_pd(rho6.v, 0); + xmm1 = _mm256_extractf128_pd(rho6.v, 1); + xmm0 = _mm_add_pd(xmm0, xmm1); + r1 = _mm256_insertf128_pd(r1, xmm0, 1); + } + } + if(rem) + { + double r[ 8 ]; + double ax[ 8 ]; + /** + * Computed dot product computation needs + * to be brought into the r buffer for + * corner cases, so that remainder computation + * can be updated in r buffer. + */ + _mm256_storeu_pd((double *)r, r0); + _mm256_storeu_pd( (double *)(r + 4), r1); + + PRAGMA_SIMD + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,scal2s) + ( *alpha, x[i], ax[i] ); + } + + PRAGMA_SIMD + for ( dim_t p = row; p < m; ++p ) + { + for ( dim_t i = 0; i < 8; ++i ) + { + PASTEMAC(d,axpys) + ( a[p + i*lda], + w[p], r[i] ); + PASTEMAC(d,axpyjs) + ( ax[i], + a[p + i*lda], z[p] ); + } + } + /** + * Final dot product computation needs be + * loaded into registers, for getting + * scaled by Alpha and finally be stored + * back into output vector. + */ + r0 = _mm256_loadu_pd((double const *)r); + r1 = _mm256_loadu_pd((double const *)(r + 4)); + } + + /** + * Storing the computed result after being + * scaled by Alpha into output vector. + */ + { + __m256d y0, y1, Alpha; + y0 = _mm256_loadu_pd(y); + y1 = _mm256_loadu_pd(y + 4); + Alpha = _mm256_broadcast_sd(alpha); + y0 = _mm256_fmadd_pd(Alpha, r0, y0); + y1 = _mm256_fmadd_pd(Alpha, r1, y1); + _mm256_storeu_pd(y, y0); + _mm256_storeu_pd(y+4, y1); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(d,type); + PASTECH(d,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(d,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 73104f817d..3bdbbed4e5 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -32,6 +32,14 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +// hemv helper function +void bli_pre_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + +void bli_post_hemv_8x8(double *a, double *x, + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) @@ -109,6 +117,9 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) +// dotxaxpyf (intrinsics) +DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) + // -- level-2 ---------------------------------------------------------------- //gemv(scalar code) From 0a604ac2a1a17e2f6857cea87122f39a8b5a141d Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Tue, 14 Dec 2021 12:01:12 +0530 Subject: [PATCH 06/63] Improved DGEMV performance for smaller sizes - Introduced two new ddotxf functions with lower fuse factor. - Changed the DGEMV framework to use new kernels to improve problem decomposition. Change-Id: I523e158fd33260d06224118fbf74f2314e03a617 --- frame/2/gemv/bli_gemv_unf_var1.c | 211 +++-- kernels/zen/1f/bli_dotxf_zen_int_8.c | 1105 +++++++++++++++++++++----- kernels/zen/bli_kernels_zen.h | 2 + 3 files changed, 1066 insertions(+), 252 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 4f0054c1f1..085fe87c45 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -34,7 +34,6 @@ */ #include "blis.h" -#define BLIS_DGEMV_VAR1_FUSE 8 #undef GENTFUNC #define GENTFUNC( ctype, ch, varname ) \ @@ -121,30 +120,30 @@ void bli_dgemv_unf_var1 ) { - double* A1; - double* y1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; + double *A1; + double *y1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; //memory pool declarations for packing vector X. - mem_t mem_bufX; - rntm_t rntm; - double *x_buf = x; - inc_t buf_incx = incx; + mem_t mem_bufX; + rntm_t rntm; + double *x_buf = x; + inc_t buf_incx = incx; bli_init_once(); - if( cntx == NULL ) cntx = bli_gks_query_cntx(); + if (cntx == NULL) + cntx = bli_gks_query_cntx(); - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at ); + bli_set_dims_incs_with_trans(transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at); - conja = bli_extract_conj( transa ); + conja = bli_extract_conj(transa); - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. // This function is invoked on all architectures including ‘generic’. // Invoke architecture specific kernels only if we are sure that we are running on zen, // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). @@ -190,88 +189,154 @@ void bli_dgemv_unf_var1 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); return; } - + if (incx > 1) { - /* + /* Initialize mem pool buffer to NULL and size to 0 "buf" and "size" fields are assigned once memory is allocated from the pool in bli_membrk_acquire_m(). This will ensure bli_mem_is_alloc() will be passed on an allocated memory if created or a NULL . - */ - mem_bufX.pblk.buf = NULL; mem_bufX.pblk.block_size = 0; - mem_bufX.buf_type = 0; mem_bufX.size = 0; - mem_bufX.pool = NULL; + */ - /* In order to get the buffer from pool via rntm access to memory broker + mem_bufX.pblk.buf = NULL; + mem_bufX.pblk.block_size = 0; + mem_bufX.buf_type = 0; + mem_bufX.size = 0; + mem_bufX.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker is needed.Following are initializations for rntm */ - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); - //calculate the size required for n_elem double elements in vector X. - size_t buffer_size = n_elem * sizeof(double); + //calculate the size required for n_elem double elements in vector X. + size_t buffer_size = n_elem * sizeof(double); - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var1(): get mem pool block\n" ); - #endif +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): get mem pool block\n"); +#endif - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufX.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX); - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufX ))) - { - x_buf = bli_mem_buffer(&mem_bufX); - - //pack X vector with non-unit stride to a temp buffer x_buf with unit stride - for(dim_t x_index = 0 ; x_index < n_elem ; x_index++) - { - *(x_buf + x_index) = *(x + (x_index * incx)) ; - } - // stride of vector x_buf =1 - buf_incx = 1; - } + /*Continue packing X if buffer memory is allocated*/ + if ((bli_mem_is_alloc(&mem_bufX))) + { + x_buf = bli_mem_buffer(&mem_bufX); + + //pack X vector with non-unit stride to a temp buffer x_buf with unit stride + for (dim_t x_index = 0; x_index < n_elem; x_index++) + { + *(x_buf + x_index) = *(x + (x_index * incx)); + } + // stride of vector x_buf =1 + buf_incx = 1; } - - for ( i = 0; i < n_iter; i += f ) + } + + dim_t fuse_factor = 8; + dim_t f_temp =0; + + if (n < 4) + { + fuse_factor = 2; + } else if (n < 8) + { + fuse_factor = 4; + } + + + for (i = 0; i < n_iter; i += f) + { + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); + + //A = a + i * row_increment + 0 * column_increment + A1 = a + (i)*rs_at; + y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + switch (f) { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR1_FUSE ); + case 8: - A1 = a + (i )*rs_at + (0 )*cs_at; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_ddotxf_zen_int_8 - ( + bli_ddotxf_zen_int_8( conja, conjx, n_elem, f, alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, + A1, cs_at, rs_at, + x_buf, buf_incx, beta, - y1, incy, - cntx - ); + y1, incy, + cntx); + + break; + default: + if (f < 4) + { + bli_ddotxf_zen_int_2( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + else + { + bli_ddotxf_zen_int_4( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } } - if ((incx > 1) && bli_mem_is_alloc( &mem_bufX )) + + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + if (f_temp < fuse_factor) { - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var1(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool - bli_membrk_release(&rntm , &mem_bufX); + switch (fuse_factor) + { + case 8: + fuse_factor = 4; + break; + case 4: + fuse_factor = 2; + break; + } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + } + + if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); +#endif + // Return the buffer to pool + bli_membrk_release(&rntm, &mem_bufX); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } void bli_sgemv_unf_var1 diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index 531a389b50..e25910fb4e 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -52,6 +52,14 @@ typedef union double d[4] __attribute__((aligned(64))); } v4df_t; +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 2 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + // ----------------------------------------------------------------------------- void bli_sdotxf_zen_int_8 @@ -430,49 +438,46 @@ void bli_ddotxf_zen_int_8 cntx_t* restrict cntx ) { - const dim_t fuse_fac = 8; - const dim_t n_elem_per_reg = 4; + const dim_t fuse_fac = 8; + const dim_t n_elem_per_reg = 4; // If the b_n dimension is zero, y is empty and there is no computation. - if ( bli_zero_dim1( b_n ) ) return; + if (bli_zero_dim1(b_n)) + return; // If the m dimension is zero, or if alpha is zero, the computation // simplifies to updating y. - if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) ) + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) { - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - b_n, - beta, - y, incy, - cntx - ); + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); return; } // If b_n is not equal to the fusing factor, then perform the entire // operation as a loop over dotxv. - if ( b_n != fuse_fac ) + if (b_n != fuse_fac) { - for ( dim_t i = 0; i < b_n; ++i ) + for (dim_t i = 0; i < b_n; ++i) { - double* a1 = a + (0 )*inca + (i )*lda; - double* x1 = x + (0 )*incx; - double* psi1 = y + (i )*incy; - - bli_ddotxv_zen_int - ( - conjat, - conjx, - m, - alpha, - a1, inca, - x1, incx, - beta, - psi1, - cntx - ); + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); } return; } @@ -493,115 +498,113 @@ void bli_ddotxf_zen_int_8 // distinguishes between (1) and (2). // Intermediate variables to hold the completed dot products - double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0, - rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0; + double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0; + double rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0; - if ( inca == 1 && incx == 1 ) + if (inca == 1 && incx == 1) { const dim_t n_iter_unroll = 1; // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); + dim_t m_viter; + + // Calculate the number of vector iterations that can occur + // for the given unroll factors. + m_viter = (m) / (n_elem_per_reg * n_iter_unroll); // Set up pointers for x and the b_n columns of A (rows of A^T). - double* restrict x0 = x; - double* restrict a0 = a + 0*lda; - double* restrict a1 = a + 1*lda; - double* restrict a2 = a + 2*lda; - double* restrict a3 = a + 3*lda; - double* restrict a4 = a + 4*lda; - double* restrict a5 = a + 5*lda; - double* restrict a6 = a + 6*lda; - double* restrict a7 = a + 7*lda; + double *restrict x0 = x; + double *restrict av[8]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + av[2] = a + 2 * lda; + av[3] = a + 3 * lda; + av[4] = a + 4 * lda; + av[5] = a + 5 * lda; + av[6] = a + 6 * lda; + av[7] = a + 7 * lda; // Initialize b_n rho vector accumulators to zero. - v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); - v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); - v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); - v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); - v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); - v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); - v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); - v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + v4df_t rhov[8]; - v4df_t x0v; - v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v; + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); - // If there are vectorized iterations, perform them with vector - // instructions. - for ( dim_t i = 0; i < m_viter; ++i ) + v4df_t xv; + v4df_t avec[8]; + + for (dim_t i = 0; i < m_viter; ++i) { // Load the input values. - x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv.v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); // perform: rho?v += a?v * x0v; - rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v ); - rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v ); - rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v ); - rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v ); - rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v ); + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv.v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv.v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv.v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv.v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[4] + 0 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[5] + 0 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[6] + 0 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[7] + 0 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv.v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv.v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv.v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv.v, rhov[7].v); x0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - a5 += n_elem_per_reg * n_iter_unroll; - a6 += n_elem_per_reg * n_iter_unroll; - a7 += n_elem_per_reg * n_iter_unroll; + av[0] += n_elem_per_reg * n_iter_unroll; + av[1] += n_elem_per_reg * n_iter_unroll; + av[2] += n_elem_per_reg * n_iter_unroll; + av[3] += n_elem_per_reg * n_iter_unroll; + av[4] += n_elem_per_reg * n_iter_unroll; + av[5] += n_elem_per_reg * n_iter_unroll; + av[6] += n_elem_per_reg * n_iter_unroll; + av[7] += n_elem_per_reg * n_iter_unroll; } -#if 0 - rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; - rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3]; - rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3]; - rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3]; - rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3]; - rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3]; - rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3]; - rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3]; -#else // Sum the elements of a given rho?v. This computes the sum of // elements within lanes and stores the sum to both elements. - rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v ); - rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v ); - rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v ); - rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v ); - rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v ); - rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v ); - rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v ); - rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v ); + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v); + rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v); + rhov[4].v = _mm256_hadd_pd(rhov[4].v, rhov[4].v); + rhov[5].v = _mm256_hadd_pd(rhov[5].v, rhov[5].v); + rhov[6].v = _mm256_hadd_pd(rhov[6].v, rhov[6].v); + rhov[7].v = _mm256_hadd_pd(rhov[7].v, rhov[7].v); // Manually add the results from above to finish the sum. - rho0 = rho0v.d[0] + rho0v.d[2]; - rho1 = rho1v.d[0] + rho1v.d[2]; - rho2 = rho2v.d[0] + rho2v.d[2]; - rho3 = rho3v.d[0] + rho3v.d[2]; - rho4 = rho4v.d[0] + rho4v.d[2]; - rho5 = rho5v.d[0] + rho5v.d[2]; - rho6 = rho6v.d[0] + rho6v.d[2]; - rho7 = rho7v.d[0] + rho7v.d[2]; -#endif + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + rho2 = rhov[2].d[0] + rhov[2].d[2]; + rho3 = rhov[3].d[0] + rhov[3].d[2]; + rho4 = rhov[4].d[0] + rhov[4].d[2]; + rho5 = rhov[5].d[0] + rhov[5].d[2]; + rho6 = rhov[6].d[0] + rhov[6].d[2]; + rho7 = rhov[7].d[0] + rhov[7].d[2]; + // Adjust for scalar subproblem. m -= n_elem_per_reg * n_iter_unroll * m_viter; a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */; x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */; - } - else if ( lda == 1 ) + + }else if (lda == 1) { const dim_t n_iter_unroll = 3; const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg; @@ -672,127 +675,871 @@ void bli_ddotxf_zen_int_8 a += n_iter_unroll * m_viter * inca; x += n_iter_unroll * m_viter * incx; } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + double *restrict a2 = a + 2 * lda; + double *restrict a3 = a + 3 * lda; + double *restrict a4 = a + 4 * lda; + double *restrict a5 = a + 5 * lda; + double *restrict a6 = a + 6 * lda; + double *restrict a7 = a + 7 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + const double a4c = *a4; + const double a5c = *a5; + const double a6c = *a6; + const double a7c = *a7; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + rho2 += a2c * x0c; + rho3 += a3c * x0c; + rho4 += a4c * x0c; + rho5 += a5c * x0c; + rho6 += a6c * x0c; + rho7 += a7c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + a5 += inca; + a6 += inca; + a7 += inca; + } + + // Now prepare the final rho values to output/accumulate back into + // the y vector. + + v4df_t rho0v, rho1v, y0v, y1v; + + // Insert the scalar rho values into a single vector. + rho0v.d[0] = rho0; + rho0v.d[1] = rho1; + rho0v.d[2] = rho2; + rho0v.d[3] = rho3; + rho1v.d[0] = rho4; + rho1v.d[1] = rho5; + rho1v.d[2] = rho6; + rho1v.d[3] = rho7; + + // Broadcast the alpha scalar. + v4df_t alphav; + alphav.v = _mm256_broadcast_sd(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm256_mul_pd(alphav.v, rho0v.v); + y1v.v = _mm256_mul_pd(alphav.v, rho1v.v); + } else { - // No vectorization possible; use scalar iterations for the entire - // problem. + // Broadcast the beta scalar. + v4df_t betav; + betav.v = _mm256_broadcast_sd(beta); + + // Load y. + if (incy == 1) + { + y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg); + y1v.v = _mm256_loadu_pd(y + 1 * n_elem_per_reg); + } + else + { + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + y0v.d[2] = *(y + 2 * incy); + y0v.d[3] = *(y + 3 * incy); + y1v.d[0] = *(y + 4 * incy); + y1v.d[1] = *(y + 5 * incy); + y1v.d[2] = *(y + 6 * incy); + y1v.d[3] = *(y + 7 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm256_mul_pd(betav.v, y0v.v); + y1v.v = _mm256_mul_pd(betav.v, y1v.v); + y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v); + y1v.v = _mm256_fmadd_pd(alphav.v, rho1v.v, y1v.v); } - // Scalar edge case. + if (incy == 1) { - // Initialize pointers for x and the b_n columns of A (rows of A^T). - double* restrict x0 = x; - double* restrict a0 = a + 0*lda; - double* restrict a1 = a + 1*lda; - double* restrict a2 = a + 2*lda; - double* restrict a3 = a + 3*lda; - double* restrict a4 = a + 4*lda; - double* restrict a5 = a + 5*lda; - double* restrict a6 = a + 6*lda; - double* restrict a7 = a + 7*lda; + // Store the output. + _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v); + _mm256_storeu_pd((y + 1 * n_elem_per_reg), y1v.v); + } + else + { + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + *(y + 2 * incy) = y0v.d[2]; + *(y + 3 * incy) = y0v.d[3]; + *(y + 4 * incy) = y1v.d[0]; + *(y + 5 * incy) = y1v.d[1]; + *(y + 6 * incy) = y1v.d[2]; + *(y + 7 * incy) = y1v.d[3]; + } +} - // If there are leftover iterations, perform them with scalar code. - for ( dim_t i = 0; i < m ; ++i ) + +void bli_ddotxf_zen_int_4 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double *restrict alpha, + double *restrict a, inc_t inca, inc_t lda, + double *restrict x, inc_t incx, + double *restrict beta, + double *restrict y, inc_t incy, + cntx_t *restrict cntx + ) +{ + const dim_t fuse_fac = 4; + const dim_t n_elem_per_reg = 4; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n != fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) { - const double x0c = *x0; + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; + // At this point, we know that b_n is exactly equal to the fusing factor. + // However, m may not be a multiple of the number of elements per vector. - rho0 += a0c * x0c; - rho1 += a1c * x0c; - rho2 += a2c * x0c; - rho3 += a3c * x0c; - rho4 += a4c * x0c; - rho5 += a5c * x0c; - rho6 += a6c * x0c; - rho7 += a7c * x0c; + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. + // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). - x0 += incx; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + // Intermediate variables to hold the completed dot products + double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0; + + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t m_viter[4], m_left = m, i; + + // Calculate the number of vector iterations that can occur for + // various unroll factors. + for (i = 0; i < 4; ++i) + { + m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]); + m_left = (m_left) % (n_elem_per_reg * n_iter_unroll[i]); + } + + // Set up pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict av[4]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + av[2] = a + 2 * lda; + av[3] = a + 3 * lda; + + // Initialize b_n rho vector accumulators to zero. + v4df_t rhov[8]; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + v4df_t xv[4]; + v4df_t avec[16]; + + // If there are vectorized iterations, perform them with vector + // instructions. + for (i = 0; i < m_viter[0]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + avec[8].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[9].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg); + avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg); + + rhov[0].v = _mm256_fmadd_pd(avec[8].v, xv[2].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[9].v, xv[2].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v); + + avec[12].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[13].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + avec[14].v = _mm256_loadu_pd(av[2] + 3 * n_elem_per_reg); + avec[15].v = _mm256_loadu_pd(av[3] + 3 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[12].v, xv[3].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[13].v, xv[3].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[14].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[15].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + } + + for (i = 0; i < m_viter[1]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + avec[8].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[9].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[10].v = _mm256_loadu_pd(av[2] + 2 * n_elem_per_reg); + avec[11].v = _mm256_loadu_pd(av[3] + 2 * n_elem_per_reg); + + rhov[0].v = _mm256_fmadd_pd(avec[8].v, xv[2].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[9].v, xv[2].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[10].v, xv[2].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[11].v, xv[2].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + } + + for (i = 0; i < m_viter[2]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + avec[4].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[2] + 1 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[3] + 1 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[1].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[1].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[1].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[1].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + } + + for (i = 0; i < m_viter[3]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[2] + 0 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[3] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[0].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[3]; + av[0] += n_elem_per_reg * n_iter_unroll[3]; + av[1] += n_elem_per_reg * n_iter_unroll[3]; + av[2] += n_elem_per_reg * n_iter_unroll[3]; + av[3] += n_elem_per_reg * n_iter_unroll[3]; + } + + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v); + + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + rhov[2].v = _mm256_hadd_pd(rhov[2].v, rhov[2].v); + rhov[3].v = _mm256_hadd_pd(rhov[3].v, rhov[3].v); + + // Manually add the results from above to finish the sum. + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + rho2 = rhov[2].d[0] + rhov[2].d[2]; + rho3 = rhov[3].d[0] + rhov[3].d[2]; + + // Adjust for scalar subproblem. + for (i = 0; i < 4; ++i) + { + m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i]; + a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */; + x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */; } } + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + double *restrict a2 = a + 2 * lda; + double *restrict a3 = a + 3 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + const double a2c = *a2; + const double a3c = *a3; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + rho2 += a2c * x0c; + rho3 += a3c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + } + // Now prepare the final rho values to output/accumulate back into // the y vector. - v4df_t rho0v, rho1v, y0v, y1v; + v4df_t rho0v, y0v; // Insert the scalar rho values into a single vector. rho0v.d[0] = rho0; rho0v.d[1] = rho1; rho0v.d[2] = rho2; rho0v.d[3] = rho3; - rho1v.d[0] = rho4; - rho1v.d[1] = rho5; - rho1v.d[2] = rho6; - rho1v.d[3] = rho7; // Broadcast the alpha scalar. - v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha ); + v4df_t alphav; + alphav.v = _mm256_broadcast_sd(alpha); // We know at this point that alpha is nonzero; however, beta may still // be zero. If beta is indeed zero, we must overwrite y rather than scale // by beta (in case y contains NaN or Inf). - if ( PASTEMAC(d,eq0)( *beta ) ) + if (PASTEMAC(d, eq0)(*beta)) { // Apply alpha to the accumulated dot product in rho: // y := alpha * rho - y0v.v = _mm256_mul_pd( alphav.v, rho0v.v ); - y1v.v = _mm256_mul_pd( alphav.v, rho1v.v ); + y0v.v = _mm256_mul_pd(alphav.v, rho0v.v); } else { // Broadcast the beta scalar. - v4df_t betav; betav.v = _mm256_broadcast_sd( beta ); + v4df_t betav; + betav.v = _mm256_broadcast_sd(beta); // Load y. - if ( incy == 1 ) + if (incy == 1) { - y0v.v = _mm256_loadu_pd( y + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + y0v.v = _mm256_loadu_pd(y + 0 * n_elem_per_reg); } else { - y0v.d[0] = *(y + 0*incy); y0v.d[1] = *(y + 1*incy); - y0v.d[2] = *(y + 2*incy); y0v.d[3] = *(y + 3*incy); - y1v.d[0] = *(y + 4*incy); y1v.d[1] = *(y + 5*incy); - y1v.d[2] = *(y + 6*incy); y1v.d[3] = *(y + 7*incy); + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + y0v.d[2] = *(y + 2 * incy); + y0v.d[3] = *(y + 3 * incy); } // Apply beta to y and alpha to the accumulated dot product in rho: // y := beta * y + alpha * rho - y0v.v = _mm256_mul_pd( betav.v, y0v.v ); - y1v.v = _mm256_mul_pd( betav.v, y1v.v ); - y0v.v = _mm256_fmadd_pd( alphav.v, rho0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( alphav.v, rho1v.v, y1v.v ); + y0v.v = _mm256_mul_pd(betav.v, y0v.v); + y0v.v = _mm256_fmadd_pd(alphav.v, rho0v.v, y0v.v); } - if ( incy == 1 ) + if (incy == 1) { // Store the output. - _mm256_storeu_pd( (y + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (y + 1*n_elem_per_reg), y1v.v ); + _mm256_storeu_pd((y + 0 * n_elem_per_reg), y0v.v); } else { - *(y + 0*incy) = y0v.d[0]; *(y + 1*incy) = y0v.d[1]; - *(y + 2*incy) = y0v.d[2]; *(y + 3*incy) = y0v.d[3]; - *(y + 4*incy) = y1v.d[0]; *(y + 5*incy) = y1v.d[1]; - *(y + 6*incy) = y1v.d[2]; *(y + 7*incy) = y1v.d[3]; + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + *(y + 2 * incy) = y0v.d[2]; + *(y + 3 * incy) = y0v.d[3]; } } +void bli_ddotxf_zen_int_2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + double *restrict alpha, + double *restrict a, inc_t inca, inc_t lda, + double *restrict x, inc_t incx, + double *restrict beta, + double *restrict y, inc_t incy, + cntx_t *restrict cntx + ) +{ + const dim_t fuse_fac = 2; + const dim_t n_elem_per_reg = 4; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(d, eq0)(*alpha)) + { + bli_dscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n != fuse_fac) + { + for (dim_t i = 0; i < b_n; ++i) + { + double *a1 = a + (0) * inca + (i)*lda; + double *x1 = x + (0) * incx; + double *psi1 = y + (i)*incy; + + bli_ddotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // At this point, we know that b_n is exactly equal to the fusing factor. + // However, m may not be a multiple of the number of elements per vector. + + // Going forward, we handle two possible storage formats of A explicitly: + // (1) A is stored by columns, or (2) A is stored by rows. Either case is + // further split into two subproblems along the m dimension: + // (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m. + // (b) a scalar part, starting at m' and ending at m. If no vectorization + // is possible then m' == 0 and thus the scalar part is the entire + // problem. If 0 < m', then the a and x pointers and m variable will + // be adjusted accordingly for the second subproblem. + // Note: since parts (b) for both (1) and (2) are so similar, they are + // factored out into one code block after the following conditional, which + // distinguishes between (1) and (2). + + // Intermediate variables to hold the completed dot products + double rho0 = 0, rho1 = 0; + + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll[4] = {8, 4, 2, 1}; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t m_viter[4], i, m_left = m; + + for (i = 0; i < 4; ++i) + { + m_viter[i] = (m_left) / (n_elem_per_reg * n_iter_unroll[i]); + m_left = (m_left) % (n_elem_per_reg * n_iter_unroll[i]); + } + + // Set up pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict av[2]; + + av[0] = a + 0 * lda; + av[1] = a + 1 * lda; + + // Initialize b_n rho vector accumulators to zero. + v4df_t rhov[8]; + + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + v4df_t xv[4]; + v4df_t avec[8]; + + for (i = 0; i < m_viter[0]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + } + + for (i = 0; i < m_viter[1]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + rhov[4].v = _mm256_fmadd_pd(avec[4].v, xv[2].v, rhov[4].v); + rhov[5].v = _mm256_fmadd_pd(avec[5].v, xv[2].v, rhov[5].v); + rhov[6].v = _mm256_fmadd_pd(avec[6].v, xv[3].v, rhov[6].v); + rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); + + x0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + } + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_add_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_add_pd(rhov[3].v, rhov[7].v); + + for (i = 0; i < m_viter[2]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + rhov[2].v = _mm256_fmadd_pd(avec[2].v, xv[1].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_pd(avec[3].v, xv[1].v, rhov[3].v); + + x0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + } + + rhov[0].v = _mm256_add_pd(rhov[0].v, rhov[2].v); + rhov[1].v = _mm256_add_pd(rhov[1].v, rhov[3].v); + + for (i = 0; i < m_viter[3]; ++i) + { + // Load the input values. + xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_pd(avec[1].v, xv[0].v, rhov[1].v); + + x0 += n_elem_per_reg * n_iter_unroll[3]; + av[0] += n_elem_per_reg * n_iter_unroll[3]; + av[1] += n_elem_per_reg * n_iter_unroll[3]; + } + + // Sum the elements of a given rho?v. This computes the sum of + // elements within lanes and stores the sum to both elements. + rhov[0].v = _mm256_hadd_pd(rhov[0].v, rhov[0].v); + rhov[1].v = _mm256_hadd_pd(rhov[1].v, rhov[1].v); + + // Manually add the results from above to finish the sum. + rho0 = rhov[0].d[0] + rhov[0].d[2]; + rho1 = rhov[1].d[0] + rhov[1].d[2]; + + // Adjust for scalar subproblem. + for (i = 0; i < 4; ++i) + { + m -= n_elem_per_reg * n_iter_unroll[i] * m_viter[i]; + a += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * inca */; + x += n_elem_per_reg * n_iter_unroll[i] * m_viter[i] /* * incx */; + } + } + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + double *restrict x0 = x; + double *restrict a0 = a + 0 * lda; + double *restrict a1 = a + 1 * lda; + + // If there are leftover iterations, perform them with scalar code. + for (dim_t i = 0; i < m; ++i) + { + const double x0c = *x0; + + const double a0c = *a0; + const double a1c = *a1; + + rho0 += a0c * x0c; + rho1 += a1c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + } + + // Now prepare the final rho values to output/accumulate back into + // the y vector. + + v2df_t rho0v, y0v; + + // Insert the scalar rho values into a single vector. + rho0v.d[0] = rho0; + rho0v.d[1] = rho1; + + // Broadcast the alpha scalar. + v2df_t alphav; + + alphav.v = _mm_load1_pd(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(d, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm_mul_pd(alphav.v, rho0v.v); + } + else + { + // Broadcast the beta scalar. + v2df_t betav; + betav.v = _mm_load1_pd(beta); + + // Load y. + if (incy == 1) + { + y0v.v = _mm_loadu_pd(y + 0 * 2); + } + else + { + y0v.d[0] = *(y + 0 * incy); + y0v.d[1] = *(y + 1 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm_mul_pd(betav.v, y0v.v); + y0v.v = _mm_fmadd_pd(alphav.v, rho0v.v, y0v.v); + } + + if (incy == 1) + { + // Store the output. + _mm_storeu_pd((y + 0 * 2), y0v.v); + } + else + { + *(y + 0 * incy) = y0v.d[0]; + *(y + 1 * incy) = y0v.d[1]; + } +} + + diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 3bdbbed4e5..e8cbe49d15 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -116,6 +116,8 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) +DOTXF_KER_PROT( double, d, dotxf_zen_int_4 ) +DOTXF_KER_PROT( double, d, dotxf_zen_int_2 ) // dotxaxpyf (intrinsics) DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) From 3f8961d7cddd9c3d842e8dde5143ea4a33b9c4ad Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Wed, 15 Dec 2021 15:11:08 +0530 Subject: [PATCH 07/63] Reduced number of threads in dgemm for small dimensions - Number of threads are reduced to 1 when the dimensions are very low. - Removed uninitialized xmm compilation warning in trsm small Change-Id: I23262fb82729af5b98ded5d36f5eed45d5255d5b --- frame/base/bli_rntm.c | 4 ++++ kernels/zen/3/bli_trsm_small.c | 3 +++ 2 files changed, 7 insertions(+) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 6a100bbe8e..dc0acf6bf9 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -574,6 +574,10 @@ void bli_nthreads_optimum( if(n < 15) n_threads_ideal = 1; else n_threads_ideal = 4; } + else if( ( m < 34) && (k < 68) && ( m < 34)) + { + n_threads_ideal = 1; + } else { if(n < 20) n_threads_ideal = 1; diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index c782a08a49..0fa8f66d5a 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -2847,6 +2847,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ @@ -3009,6 +3010,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3);\ @@ -3116,6 +3118,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref #define BLIS_PRE_STRSM_SMALL_1N_2M(AlphaVal,b11,cs_b)\ ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ \ + xmm5 = _mm_setzero_ps();\ xmm5 = _mm_loadl_pi(xmm5,(__m64*)(b11));\ ymm6 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ ymm3 = _mm256_fmsub_ps(ymm6, ymm15, ymm3); From bd148ae8b86fe26dc0a2c5360e8952738155458c Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Fri, 17 Dec 2021 14:04:13 +0530 Subject: [PATCH 08/63] Fixed DDOTXF Bug - Corrected xv and avec indexing in vector loop of bli_ddotxf_zen_int_2 Change-Id: I4c511236aad09541fe6b1295103a1a8b54ceec39 --- kernels/zen/1f/bli_dotxf_zen_int_8.c | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index e25910fb4e..ad27403bdc 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -1337,19 +1337,19 @@ void bli_ddotxf_zen_int_2 rhov[7].v = _mm256_fmadd_pd(avec[7].v, xv[3].v, rhov[7].v); // Load the input values. - xv[0].v = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); - xv[1].v = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); - xv[2].v = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); - xv[3].v = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); - - avec[0].v = _mm256_loadu_pd(av[0] + 0 * n_elem_per_reg); - avec[1].v = _mm256_loadu_pd(av[1] + 0 * n_elem_per_reg); - avec[2].v = _mm256_loadu_pd(av[0] + 1 * n_elem_per_reg); - avec[3].v = _mm256_loadu_pd(av[1] + 1 * n_elem_per_reg); - avec[4].v = _mm256_loadu_pd(av[0] + 2 * n_elem_per_reg); - avec[5].v = _mm256_loadu_pd(av[1] + 2 * n_elem_per_reg); - avec[6].v = _mm256_loadu_pd(av[0] + 3 * n_elem_per_reg); - avec[7].v = _mm256_loadu_pd(av[1] + 3 * n_elem_per_reg); + xv[0].v = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[1].v = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[2].v = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[3].v = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); + + avec[0].v = _mm256_loadu_pd(av[0] + 4 * n_elem_per_reg); + avec[1].v = _mm256_loadu_pd(av[1] + 4 * n_elem_per_reg); + avec[2].v = _mm256_loadu_pd(av[0] + 5 * n_elem_per_reg); + avec[3].v = _mm256_loadu_pd(av[1] + 5 * n_elem_per_reg); + avec[4].v = _mm256_loadu_pd(av[0] + 6 * n_elem_per_reg); + avec[5].v = _mm256_loadu_pd(av[1] + 6 * n_elem_per_reg); + avec[6].v = _mm256_loadu_pd(av[0] + 7 * n_elem_per_reg); + avec[7].v = _mm256_loadu_pd(av[1] + 7 * n_elem_per_reg); // perform: rho?v += a?v * x0v; rhov[0].v = _mm256_fmadd_pd(avec[0].v, xv[0].v, rhov[0].v); From 389a734c1075c061372218177dfbb6dbe91f5510 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Mon, 20 Dec 2021 12:17:05 +0530 Subject: [PATCH 09/63] Improved AXPYV Kernel performance - Increased the unroll factor of the loop by 15 in SAXPYV - Increased the unroll factor of the loop by 12 in DAXPYV - The above changes were made for better register utilization Change-Id: I69ad1fec2fcf958dbd1bfd71378641274b43a6aa --- kernels/zen/1/bli_axpyv_zen_int10.c | 150 ++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 8 deletions(-) diff --git a/kernels/zen/1/bli_axpyv_zen_int10.c b/kernels/zen/1/bli_axpyv_zen_int10.c index 6f953e6f4c..4ef6981cd7 100644 --- a/kernels/zen/1/bli_axpyv_zen_int10.c +++ b/kernels/zen/1/bli_axpyv_zen_int10.c @@ -75,9 +75,9 @@ void bli_saxpyv_zen_int10 float* restrict y0; __m256 alphav; - __m256 xv[10]; - __m256 yv[10]; - __m256 zv[10]; + __m256 xv[15]; + __m256 yv[15]; + __m256 zv[15]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq0)( *alpha ) ) @@ -95,7 +95,78 @@ void bli_saxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_ss( alpha ); - for ( i = 0; (i + 79) < n; i += 80 ) + for (i = 0; (i + 119) < n; i += 120) + { + // 120 elements will be processed per loop; 15 FMAs will run per loop. + xv[0] = _mm256_loadu_ps(x0 + 0 * n_elem_per_reg); + xv[1] = _mm256_loadu_ps(x0 + 1 * n_elem_per_reg); + xv[2] = _mm256_loadu_ps(x0 + 2 * n_elem_per_reg); + xv[3] = _mm256_loadu_ps(x0 + 3 * n_elem_per_reg); + xv[4] = _mm256_loadu_ps(x0 + 4 * n_elem_per_reg); + xv[5] = _mm256_loadu_ps(x0 + 5 * n_elem_per_reg); + xv[6] = _mm256_loadu_ps(x0 + 6 * n_elem_per_reg); + xv[7] = _mm256_loadu_ps(x0 + 7 * n_elem_per_reg); + xv[8] = _mm256_loadu_ps(x0 + 8 * n_elem_per_reg); + xv[9] = _mm256_loadu_ps(x0 + 9 * n_elem_per_reg); + xv[10] = _mm256_loadu_ps(x0 + 10 * n_elem_per_reg); + xv[11] = _mm256_loadu_ps(x0 + 11 * n_elem_per_reg); + xv[12] = _mm256_loadu_ps(x0 + 12 * n_elem_per_reg); + xv[13] = _mm256_loadu_ps(x0 + 13 * n_elem_per_reg); + xv[14] = _mm256_loadu_ps(x0 + 14 * n_elem_per_reg); + + yv[0] = _mm256_loadu_ps(y0 + 0 * n_elem_per_reg); + yv[1] = _mm256_loadu_ps(y0 + 1 * n_elem_per_reg); + yv[2] = _mm256_loadu_ps(y0 + 2 * n_elem_per_reg); + yv[3] = _mm256_loadu_ps(y0 + 3 * n_elem_per_reg); + yv[4] = _mm256_loadu_ps(y0 + 4 * n_elem_per_reg); + yv[5] = _mm256_loadu_ps(y0 + 5 * n_elem_per_reg); + yv[6] = _mm256_loadu_ps(y0 + 6 * n_elem_per_reg); + yv[7] = _mm256_loadu_ps(y0 + 7 * n_elem_per_reg); + yv[8] = _mm256_loadu_ps(y0 + 8 * n_elem_per_reg); + yv[9] = _mm256_loadu_ps(y0 + 9 * n_elem_per_reg); + yv[10] = _mm256_loadu_ps(y0 + 10 * n_elem_per_reg); + yv[11] = _mm256_loadu_ps(y0 + 11 * n_elem_per_reg); + yv[12] = _mm256_loadu_ps(y0 + 12 * n_elem_per_reg); + yv[13] = _mm256_loadu_ps(y0 + 13 * n_elem_per_reg); + yv[14] = _mm256_loadu_ps(y0 + 14 * n_elem_per_reg); + + zv[0] = _mm256_fmadd_ps(xv[0], alphav, yv[0]); + zv[1] = _mm256_fmadd_ps(xv[1], alphav, yv[1]); + zv[2] = _mm256_fmadd_ps(xv[2], alphav, yv[2]); + zv[3] = _mm256_fmadd_ps(xv[3], alphav, yv[3]); + zv[4] = _mm256_fmadd_ps(xv[4], alphav, yv[4]); + zv[5] = _mm256_fmadd_ps(xv[5], alphav, yv[5]); + zv[6] = _mm256_fmadd_ps(xv[6], alphav, yv[6]); + zv[7] = _mm256_fmadd_ps(xv[7], alphav, yv[7]); + zv[8] = _mm256_fmadd_ps(xv[8], alphav, yv[8]); + zv[9] = _mm256_fmadd_ps(xv[9], alphav, yv[9]); + zv[10] = _mm256_fmadd_ps(xv[10], alphav, yv[10]); + zv[11] = _mm256_fmadd_ps(xv[11], alphav, yv[11]); + zv[12] = _mm256_fmadd_ps(xv[12], alphav, yv[12]); + zv[13] = _mm256_fmadd_ps(xv[13], alphav, yv[13]); + zv[14] = _mm256_fmadd_ps(xv[14], alphav, yv[14]); + + _mm256_storeu_ps((y0 + 0 * n_elem_per_reg), zv[0]); + _mm256_storeu_ps((y0 + 1 * n_elem_per_reg), zv[1]); + _mm256_storeu_ps((y0 + 2 * n_elem_per_reg), zv[2]); + _mm256_storeu_ps((y0 + 3 * n_elem_per_reg), zv[3]); + _mm256_storeu_ps((y0 + 4 * n_elem_per_reg), zv[4]); + _mm256_storeu_ps((y0 + 5 * n_elem_per_reg), zv[5]); + _mm256_storeu_ps((y0 + 6 * n_elem_per_reg), zv[6]); + _mm256_storeu_ps((y0 + 7 * n_elem_per_reg), zv[7]); + _mm256_storeu_ps((y0 + 8 * n_elem_per_reg), zv[8]); + _mm256_storeu_ps((y0 + 9 * n_elem_per_reg), zv[9]); + _mm256_storeu_ps((y0 + 10 * n_elem_per_reg), zv[10]); + _mm256_storeu_ps((y0 + 11 * n_elem_per_reg), zv[11]); + _mm256_storeu_ps((y0 + 12 * n_elem_per_reg), zv[12]); + _mm256_storeu_ps((y0 + 13 * n_elem_per_reg), zv[13]); + _mm256_storeu_ps((y0 + 14 * n_elem_per_reg), zv[14]); + + x0 += 15 * n_elem_per_reg; + y0 += 15 * n_elem_per_reg; + } + + for (; (i + 79) < n; i += 80 ) { // 80 elements will be processed per loop; 10 FMAs will run per loop. xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); @@ -288,9 +359,9 @@ void bli_daxpyv_zen_int10 double* restrict y0 = y; __m256d alphav; - __m256d xv[10]; - __m256d yv[10]; - __m256d zv[10]; + __m256d xv[13]; + __m256d yv[13]; + __m256d zv[13]; // If the vector dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq0)( *alpha ) ) @@ -308,7 +379,70 @@ void bli_daxpyv_zen_int10 // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); - for ( i = 0; (i + 39) < n; i += 40 ) + for (i = 0; (i + 51) < n; i += 52) + { + // 52 elements will be processed per loop; 13 FMAs will run per loop. + xv[0] = _mm256_loadu_pd(x0 + 0 * n_elem_per_reg); + xv[1] = _mm256_loadu_pd(x0 + 1 * n_elem_per_reg); + xv[2] = _mm256_loadu_pd(x0 + 2 * n_elem_per_reg); + xv[3] = _mm256_loadu_pd(x0 + 3 * n_elem_per_reg); + xv[4] = _mm256_loadu_pd(x0 + 4 * n_elem_per_reg); + xv[5] = _mm256_loadu_pd(x0 + 5 * n_elem_per_reg); + xv[6] = _mm256_loadu_pd(x0 + 6 * n_elem_per_reg); + xv[7] = _mm256_loadu_pd(x0 + 7 * n_elem_per_reg); + xv[8] = _mm256_loadu_pd(x0 + 8 * n_elem_per_reg); + xv[9] = _mm256_loadu_pd(x0 + 9 * n_elem_per_reg); + xv[10] = _mm256_loadu_pd(x0 + 10 * n_elem_per_reg); + xv[11] = _mm256_loadu_pd(x0 + 11 * n_elem_per_reg); + xv[12] = _mm256_loadu_pd(x0 + 12 * n_elem_per_reg); + + yv[0] = _mm256_loadu_pd(y0 + 0 * n_elem_per_reg); + yv[1] = _mm256_loadu_pd(y0 + 1 * n_elem_per_reg); + yv[2] = _mm256_loadu_pd(y0 + 2 * n_elem_per_reg); + yv[3] = _mm256_loadu_pd(y0 + 3 * n_elem_per_reg); + yv[4] = _mm256_loadu_pd(y0 + 4 * n_elem_per_reg); + yv[5] = _mm256_loadu_pd(y0 + 5 * n_elem_per_reg); + yv[6] = _mm256_loadu_pd(y0 + 6 * n_elem_per_reg); + yv[7] = _mm256_loadu_pd(y0 + 7 * n_elem_per_reg); + yv[8] = _mm256_loadu_pd(y0 + 8 * n_elem_per_reg); + yv[9] = _mm256_loadu_pd(y0 + 9 * n_elem_per_reg); + yv[10] = _mm256_loadu_pd(y0 + 10 * n_elem_per_reg); + yv[11] = _mm256_loadu_pd(y0 + 11 * n_elem_per_reg); + yv[12] = _mm256_loadu_pd(y0 + 12 * n_elem_per_reg); + + zv[0] = _mm256_fmadd_pd(xv[0], alphav, yv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphav, yv[1]); + zv[2] = _mm256_fmadd_pd(xv[2], alphav, yv[2]); + zv[3] = _mm256_fmadd_pd(xv[3], alphav, yv[3]); + zv[4] = _mm256_fmadd_pd(xv[4], alphav, yv[4]); + zv[5] = _mm256_fmadd_pd(xv[5], alphav, yv[5]); + zv[6] = _mm256_fmadd_pd(xv[6], alphav, yv[6]); + zv[7] = _mm256_fmadd_pd(xv[7], alphav, yv[7]); + zv[8] = _mm256_fmadd_pd(xv[8], alphav, yv[8]); + zv[9] = _mm256_fmadd_pd(xv[9], alphav, yv[9]); + zv[10] = _mm256_fmadd_pd(xv[10], alphav, yv[10]); + zv[11] = _mm256_fmadd_pd(xv[11], alphav, yv[11]); + zv[12] = _mm256_fmadd_pd(xv[12], alphav, yv[12]); + + _mm256_storeu_pd((y0 + 0 * n_elem_per_reg), zv[0]); + _mm256_storeu_pd((y0 + 1 * n_elem_per_reg), zv[1]); + _mm256_storeu_pd((y0 + 2 * n_elem_per_reg), zv[2]); + _mm256_storeu_pd((y0 + 3 * n_elem_per_reg), zv[3]); + _mm256_storeu_pd((y0 + 4 * n_elem_per_reg), zv[4]); + _mm256_storeu_pd((y0 + 5 * n_elem_per_reg), zv[5]); + _mm256_storeu_pd((y0 + 6 * n_elem_per_reg), zv[6]); + _mm256_storeu_pd((y0 + 7 * n_elem_per_reg), zv[7]); + _mm256_storeu_pd((y0 + 8 * n_elem_per_reg), zv[8]); + _mm256_storeu_pd((y0 + 9 * n_elem_per_reg), zv[9]); + _mm256_storeu_pd((y0 + 10 * n_elem_per_reg), zv[10]); + _mm256_storeu_pd((y0 + 11 * n_elem_per_reg), zv[11]); + _mm256_storeu_pd((y0 + 12 * n_elem_per_reg), zv[12]); + + x0 += 13 * n_elem_per_reg; + y0 += 13 * n_elem_per_reg; + } + + for ( ; (i + 39) < n; i += 40 ) { // 40 elements will be processed per loop; 10 FMAs will run per loop. xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); From c81be92dca49209a007e2bc0f8b6a9b104cb4007 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 22 Dec 2021 14:47:15 +0530 Subject: [PATCH 10/63] AOCL-Windows: Updating the blis windows build system. 1. Removed the libomp.lib hardcoded from cmake scripts and made it user configurable. By default libomp.lib is used as an omp library. 2. Added the STATIC_LIBRARY_OPTIONS property in set_target_properties cmake command to link omp library to build static-mt blis library. 3. Updated the blis_ref_kernel_mirror.py to give support for zen4 architecture. AMD-Internal: CPUPL-1630 Change-Id: I54b04cde2fa6a1ddc4b4303f1da808c1efe0484a --- CMakeLists.txt | 13 +++-- build/blis_ref_kernel_mirror.py | 31 +++++++---- test/CMakeLists.txt | 96 ++++++++++++++++----------------- testsuite/CMakeLists.txt | 4 +- 4 files changed, 76 insertions(+), 68 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index d6885e3a38..e2cb3818e6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,8 +10,7 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OPENMP_PATH "C:\\Program Files\\LLVM\\lib" CACHE STRING "openmp library -path") +SET(OMP_LIB "C:\\Program Files\\LLVM\\lib\\libomp.lib" CACHE STRING "openmp library path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") @@ -532,15 +531,14 @@ file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") -message( STATUS "OPENMP PATH:" ${OPENMP_PATH}) -link_directories("${OPENMP_PATH}") +message( STATUS "OPENMP Library:" ${OMP_LIB}) if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OPENMP_PATH}/libomp.lib") + target_link_libraries("${PROJECT_NAME}" PUBLIC "${OMP_LIB}") endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -550,9 +548,10 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OPENMP_PATH}/libomp.lib") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OMP_LIB}") + else() + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") endif() link_directories(${CMAKE_LIBRARY_OUTPUT_DIRECTORY}) diff --git a/build/blis_ref_kernel_mirror.py b/build/blis_ref_kernel_mirror.py index b756eb30b6..8ef90a12af 100644 --- a/build/blis_ref_kernel_mirror.py +++ b/build/blis_ref_kernel_mirror.py @@ -68,11 +68,13 @@ def remove_lines_in_file(filename): with open(filename, 'r') as fd: file_content = fd.read() file_content = file_content.replace( - 'if(${TARGET_ARCH} STREQUAL amdzen)\nadd_subdirectory(${CMAKE_BINARY_' - 'DIR}/ref_kernels/generic ${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' - 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ${CMAKE_BINARY_' - 'DIR}/ref_kernels/zen)\nadd_subdirectory(${CMAKE_BINARY_DIR}/' - 'ref_kernels/zen2 ${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' + 'if(${TARGET_ARCH} STREQUAL amdzen)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/generic ' + '${CMAKE_BINARY_DIR}/ref_kernels/generic)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen)\n' + 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen2 ' + '${CMAKE_BINARY_DIR}/ref_kernels/zen2)\n' 'add_subdirectory(${CMAKE_BINARY_DIR}/ref_kernels/zen3 ' '${CMAKE_BINARY_DIR}/ref_kernels/zen3)\nelse()', '\n') data = file_content.replace('endif()', '\n') @@ -111,6 +113,7 @@ def add_macro_to_cfiles(cfiles, macro): create_folder(os.path.join(dest_path, 'zen')) create_folder(os.path.join(dest_path, 'zen2')) create_folder(os.path.join(dest_path, 'zen3')) + create_folder(os.path.join(dest_path, 'zen4')) create_folder(os.path.join(dest_path, 'generic')) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen'))) @@ -118,6 +121,8 @@ def add_macro_to_cfiles(cfiles, macro): temp, os.path.join(dest_path, 'zen2'))) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'zen3'))) + execute_and_check('XCOPY {} {} /E'.format( + temp, os.path.join(dest_path, 'zen4'))) execute_and_check('XCOPY {} {} /E'.format( temp, os.path.join(dest_path, 'generic'))) remove_folder(temp) @@ -129,6 +134,8 @@ def add_macro_to_cfiles(cfiles, macro): dest_path, 'zen2', 'CMakeLists.txt')) remove_lines_in_file(os.path.join( dest_path, 'zen3', 'CMakeLists.txt')) + remove_lines_in_file(os.path.join( + dest_path, 'zen4', 'CMakeLists.txt')) cfiles_in_generic = execute_and_check('cd {} && dir / s / b / o: gn *.c' .format(os.path.join(dest_path, 'generic'))) @@ -136,20 +143,22 @@ def add_macro_to_cfiles(cfiles, macro): add_macro_to_cfiles(cfiles_in_generic, '\n#define BLIS_CNAME_INFIX _generic\n') cfiles_in_zen = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen'))) + .format(os.path.join(dest_path, 'zen'))) cfiles_in_zen = cfiles_in_zen.split('\r\n') add_macro_to_cfiles(cfiles_in_zen, '\n#define BLIS_CNAME_INFIX _zen\n') cfiles_in_zen2 = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen2'))) + .format(os.path.join(dest_path, 'zen2'))) cfiles_in_zen2 = cfiles_in_zen2.split('\r\n') add_macro_to_cfiles(cfiles_in_zen2, '\n#define BLIS_CNAME_INFIX _zen2\n') cfiles_in_zen3 = execute_and_check('cd {} && dir / s / b / o: gn *.c' - .format(os.path.join(dest_path, - 'zen3'))) + .format(os.path.join(dest_path, 'zen3'))) cfiles_in_zen3 = cfiles_in_zen3.split('\r\n') add_macro_to_cfiles(cfiles_in_zen3, '\n#define BLIS_CNAME_INFIX _zen3\n') + cfiles_in_zen4 = execute_and_check('cd {} && dir / s / b / o: gn *.c' + .format(os.path.join(dest_path, 'zen4'))) + cfiles_in_zen4 = cfiles_in_zen4.split('\r\n') + add_macro_to_cfiles(cfiles_in_zen4, + '\n#define BLIS_CNAME_INFIX _zen4\n') diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 3b0315c9ae..fe8f7bac98 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -4,169 +4,169 @@ add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAminv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAminv "${OMP_LIB}") endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpyv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpyv "${OMP_LIB}") endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestAxpbyv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestAxpbyv "${OMP_LIB}") endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCopyv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCopyv "${OMP_LIB}") endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestCabs1 "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestCabs1 "${OMP_LIB}") endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestDotv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestDotv "${OMP_LIB}") endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm "${OMP_LIB}") endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmBatch "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmBatch "${OMP_LIB}") endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemm3m "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemm3m "${OMP_LIB}") endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemmt "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemmt "${OMP_LIB}") endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGemv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGemv "${OMP_LIB}") endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestGer "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestGer "${OMP_LIB}") endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemm "${OMP_LIB}") endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHemv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHemv "${OMP_LIB}") endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer "${OMP_LIB}") endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2 "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2 "${OMP_LIB}") endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHer2k "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHer2k "${OMP_LIB}") endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestHerk "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestHerk "${OMP_LIB}") endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestScalv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestScalv "${OMP_LIB}") endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestSwapv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestSwapv "${OMP_LIB}") endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmm "${OMP_LIB}") endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrmv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrmv "${OMP_LIB}") endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsm "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsm "${OMP_LIB}") endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(TestTrsv "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(TestTrsv "${OMP_LIB}") endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index f03d094782..613f9e3861 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -7,8 +7,8 @@ add_executable(test_libblis "") add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP) - target_link_libraries(test_libblis "${OPENMP_PATH}/libomp.lib") +if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) + target_link_libraries(test_libblis "${OMP_LIB}") endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") From 45d2264e4bd74caca63df520e72b66a070eb4480 Mon Sep 17 00:00:00 2001 From: HariharaSudhan S Date: Fri, 24 Dec 2021 00:05:13 -0500 Subject: [PATCH 11/63] Merge "Improved AXPYV Kernel performance" into amd-staging-genoa-4.0 From a48c6da457709359522e9f438d7420eedb768025 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Mon, 15 Nov 2021 23:28:33 +0530 Subject: [PATCH 12/63] Improved SCALV kernel performance. - Unrolled the loop by a greater factor. Incorporated switch case to decide unrolling factor according to the input size. - Removed unused structs. AMD-Internal: [CPUPL-1974] Change-Id: Iee9d7defcc8c582ca0420f84c4fb2c202dabe3e7 --- kernels/zen/1/bli_scalv_zen_int10.c | 666 +++++++++++++++++----------- 1 file changed, 404 insertions(+), 262 deletions(-) diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index 6c7f52e161..de9d8339d3 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -36,23 +36,6 @@ #include "immintrin.h" #include "blis.h" - -/* Union data structure to access AVX registers - One 256-bit AVX register holds 8 SP elements. */ -typedef union -{ - __m256 v; - float f[8] __attribute__((aligned(64))); -} v8sf_t; - -/* Union data structure to access AVX registers -* One 256-bit AVX register holds 4 DP elements. */ -typedef union -{ - __m256d v; - double d[4] __attribute__((aligned(64))); -} v4df_t; - // ----------------------------------------------------------------------------- void bli_sscalv_zen_int10 @@ -66,13 +49,13 @@ void bli_sscalv_zen_int10 { const dim_t n_elem_per_reg = 8; - dim_t i; + dim_t i = 0; float* restrict x0; __m256 alphav; - __m256 xv[10]; - __m256 zv[10]; + __m256 xv[16]; + __m256 zv[16]; // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(s,eq1)( *alpha ) ) return; @@ -111,140 +94,218 @@ void bli_sscalv_zen_int10 { // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_ss( alpha ); + dim_t option; - for ( i = 0; (i + 79) < n; i += 80 ) + // Unroll and the loop used is picked based on the input size. + if( n < 300) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - zv[4] = _mm256_mul_ps( alphav, xv[4] ); - zv[5] = _mm256_mul_ps( alphav, xv[5] ); - zv[6] = _mm256_mul_ps( alphav, xv[6] ); - zv[7] = _mm256_mul_ps( alphav, xv[7] ); - zv[8] = _mm256_mul_ps( alphav, xv[8] ); - zv[9] = _mm256_mul_ps( alphav, xv[9] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; + option = 2; } - - for ( ; (i + 39) < n; i += 40 ) - { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - zv[4] = _mm256_mul_ps( alphav, xv[4] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; - } - - for ( ; (i + 31) < n; i += 32 ) + else if( n < 500) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - zv[2] = _mm256_mul_ps( alphav, xv[2] ); - zv[3] = _mm256_mul_ps( alphav, xv[3] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; + option = 1; } - - for ( ; (i + 15) < n; i += 16 ) + else { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - zv[1] = _mm256_mul_ps( alphav, xv[1] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; + option = 0; } - for ( ; (i + 7) < n; i += 8 ) + switch(option) { - // Load the input values. - xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_ps( alphav, xv[0] ); - - // Store the output. - _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - *x0 *= *alpha; - - x0 += 1; + case 0: + + for ( ; (i + 127) < n; i += 128 ) + { + //Load the input values + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // Perform : x := alpha * x; + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + + // Store the result + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + zv[6] = _mm256_mul_ps( alphav, xv[6] ); + zv[7] = _mm256_mul_ps( alphav, xv[7] ); + + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_ps( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_ps( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_ps( alphav, xv[8] ); + zv[9] = _mm256_mul_ps( alphav, xv[9] ); + zv[10] = _mm256_mul_ps( alphav, xv[10] ); + zv[11] = _mm256_mul_ps( alphav, xv[11] ); + + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_ps( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_ps( (x0 + 11*n_elem_per_reg), zv[11] ); + + xv[12] = _mm256_loadu_ps( x0 + 12*n_elem_per_reg ); + xv[13] = _mm256_loadu_ps( x0 + 13*n_elem_per_reg ); + xv[14] = _mm256_loadu_ps( x0 + 14*n_elem_per_reg ); + xv[15] = _mm256_loadu_ps( x0 + 15*n_elem_per_reg ); + + zv[12] = _mm256_mul_ps( alphav, xv[12] ); + zv[13] = _mm256_mul_ps( alphav, xv[13] ); + zv[14] = _mm256_mul_ps( alphav, xv[14] ); + zv[15] = _mm256_mul_ps( alphav, xv[15] ); + + _mm256_storeu_ps( (x0 + 12*n_elem_per_reg), zv[12] ); + _mm256_storeu_ps( (x0 + 13*n_elem_per_reg), zv[13] ); + _mm256_storeu_ps( (x0 + 14*n_elem_per_reg), zv[14] ); + _mm256_storeu_ps( (x0 + 15*n_elem_per_reg), zv[15] ); + + x0 += 16*n_elem_per_reg; + } + + case 1 : + + for ( ; (i + 95) < n; i += 96 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_ps( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_ps( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + zv[6] = _mm256_mul_ps( alphav, xv[6] ); + zv[7] = _mm256_mul_ps( alphav, xv[7] ); + + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_ps( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_ps( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_ps( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_ps( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_ps( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_ps( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_ps( alphav, xv[8] ); + zv[9] = _mm256_mul_ps( alphav, xv[9] ); + zv[10] = _mm256_mul_ps( alphav, xv[10] ); + zv[11] = _mm256_mul_ps( alphav, xv[11] ); + + _mm256_storeu_ps( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_ps( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_ps( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_ps( (x0 + 11*n_elem_per_reg), zv[11] ); + + x0 += 12*n_elem_per_reg; + } + + case 2: + + for ( ; (i + 47) < n; i += 48 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + xv[4] = _mm256_loadu_ps( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_ps( x0 + 5*n_elem_per_reg ); + + zv[3] = _mm256_mul_ps( alphav, xv[3] ); + zv[4] = _mm256_mul_ps( alphav, xv[4] ); + zv[5] = _mm256_mul_ps( alphav, xv[5] ); + + _mm256_storeu_ps( (x0 + 3*n_elem_per_reg), zv[3] ); + _mm256_storeu_ps( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_ps( (x0 + 5*n_elem_per_reg), zv[5] ); + + x0 += 6*n_elem_per_reg; + } + + for ( ; (i + 23) < n; i += 24 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + zv[1] = _mm256_mul_ps( alphav, xv[1] ); + zv[2] = _mm256_mul_ps( alphav, xv[2] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_ps( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_ps( (x0 + 2*n_elem_per_reg), zv[2] ); + + x0 += 3*n_elem_per_reg; + } + + for ( ; (i + 7) < n; i += 8 ) + { + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_mul_ps( alphav, xv[0] ); + + _mm256_storeu_ps( (x0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + *x0 *= *alpha; + + x0 += 1; + } } } else { const float alphac = *alpha; - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { *x0 *= alphac; @@ -266,13 +327,13 @@ void bli_dscalv_zen_int10 { const dim_t n_elem_per_reg = 4; - dim_t i; + dim_t i = 0; double* restrict x0; __m256d alphav; - __m256d xv[10]; - __m256d zv[10]; + __m256d xv[16]; + __m256d zv[16]; // If the vector dimension is zero, or if alpha is unit, return early. if ( bli_zero_dim1( n ) || PASTEMAC(d,eq1)( *alpha ) ) return; @@ -312,140 +373,221 @@ void bli_dscalv_zen_int10 { // Broadcast the alpha scalar to all elements of a vector register. alphav = _mm256_broadcast_sd( alpha ); + dim_t option; - for ( i = 0; (i + 39) < n; i += 40 ) + // Unroll and the loop used is picked based on the input size. + if(n < 200) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); - xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); - xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); - xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); - xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - zv[4] = _mm256_mul_pd( alphav, xv[4] ); - zv[5] = _mm256_mul_pd( alphav, xv[5] ); - zv[6] = _mm256_mul_pd( alphav, xv[6] ); - zv[7] = _mm256_mul_pd( alphav, xv[7] ); - zv[8] = _mm256_mul_pd( alphav, xv[8] ); - zv[9] = _mm256_mul_pd( alphav, xv[9] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); - _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); - _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); - _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); - _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); - _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); - - x0 += 10*n_elem_per_reg; + option = 2; } - - for ( ; (i + 19) < n; i += 20 ) + else if(n < 500) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - zv[4] = _mm256_mul_pd( alphav, xv[4] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); - - x0 += 5*n_elem_per_reg; + option = 1; } - - for ( ; (i + 15) < n; i += 16 ) + else { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - zv[2] = _mm256_mul_pd( alphav, xv[2] ); - zv[3] = _mm256_mul_pd( alphav, xv[3] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); - _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); - - x0 += 4*n_elem_per_reg; + option = 0; } - for ( ; (i + 7) < n; i += 8 ) + switch(option) { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - zv[1] = _mm256_mul_pd( alphav, xv[1] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); - - x0 += 2*n_elem_per_reg; - } - - for ( ; (i + 3) < n; i += 4 ) - { - // Load the input values. - xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - - // perform : x := alpha * x; - zv[0] = _mm256_mul_pd( alphav, xv[0] ); - - // Store the output. - _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); - - x0 += 1*n_elem_per_reg; - } - - for ( ; (i + 0) < n; i += 1 ) - { - *x0 *= *alpha; - - x0 += 1; + case 0: + + for (; (i + 63) < n; i += 64 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_pd( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_pd( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_pd( alphav, xv[8] ); + zv[9] = _mm256_mul_pd( alphav, xv[9] ); + zv[10] = _mm256_mul_pd( alphav, xv[10] ); + zv[11] = _mm256_mul_pd( alphav, xv[11] ); + + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), zv[11] ); + + xv[12] = _mm256_loadu_pd( x0 + 12*n_elem_per_reg ); + xv[13] = _mm256_loadu_pd( x0 + 13*n_elem_per_reg ); + xv[14] = _mm256_loadu_pd( x0 + 14*n_elem_per_reg ); + xv[15] = _mm256_loadu_pd( x0 + 15*n_elem_per_reg ); + + zv[12] = _mm256_mul_pd( alphav, xv[12] ); + zv[13] = _mm256_mul_pd( alphav, xv[13] ); + zv[14] = _mm256_mul_pd( alphav, xv[14] ); + zv[15] = _mm256_mul_pd( alphav, xv[15] ); + + _mm256_storeu_pd( (x0 + 12*n_elem_per_reg), zv[12] ); + _mm256_storeu_pd( (x0 + 13*n_elem_per_reg), zv[13] ); + _mm256_storeu_pd( (x0 + 14*n_elem_per_reg), zv[14] ); + _mm256_storeu_pd( (x0 + 15*n_elem_per_reg), zv[15] ); + + x0 += 16*n_elem_per_reg; + } + + for (; (i + 47) < n; i += 48 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + xv[8] = _mm256_loadu_pd( x0 + 8*n_elem_per_reg ); + xv[9] = _mm256_loadu_pd( x0 + 9*n_elem_per_reg ); + xv[10] = _mm256_loadu_pd( x0 + 10*n_elem_per_reg ); + xv[11] = _mm256_loadu_pd( x0 + 11*n_elem_per_reg ); + + zv[8] = _mm256_mul_pd( alphav, xv[8] ); + zv[9] = _mm256_mul_pd( alphav, xv[9] ); + zv[10] = _mm256_mul_pd( alphav, xv[10] ); + zv[11] = _mm256_mul_pd( alphav, xv[11] ); + + _mm256_storeu_pd( (x0 + 8*n_elem_per_reg), zv[8] ); + _mm256_storeu_pd( (x0 + 9*n_elem_per_reg), zv[9] ); + _mm256_storeu_pd( (x0 + 10*n_elem_per_reg), zv[10] ); + _mm256_storeu_pd( (x0 + 11*n_elem_per_reg), zv[11] ); + + x0 += 12*n_elem_per_reg; + } + + case 1: + + for (; (i + 31) < n; i += 32 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + zv[3] = _mm256_mul_pd( alphav, xv[3] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (x0 + 3*n_elem_per_reg), zv[3] ); + + xv[4] = _mm256_loadu_pd( x0 + 4*n_elem_per_reg ); + xv[5] = _mm256_loadu_pd( x0 + 5*n_elem_per_reg ); + xv[6] = _mm256_loadu_pd( x0 + 6*n_elem_per_reg ); + xv[7] = _mm256_loadu_pd( x0 + 7*n_elem_per_reg ); + + zv[4] = _mm256_mul_pd( alphav, xv[4] ); + zv[5] = _mm256_mul_pd( alphav, xv[5] ); + zv[6] = _mm256_mul_pd( alphav, xv[6] ); + zv[7] = _mm256_mul_pd( alphav, xv[7] ); + + _mm256_storeu_pd( (x0 + 4*n_elem_per_reg), zv[4] ); + _mm256_storeu_pd( (x0 + 5*n_elem_per_reg), zv[5] ); + _mm256_storeu_pd( (x0 + 6*n_elem_per_reg), zv[6] ); + _mm256_storeu_pd( (x0 + 7*n_elem_per_reg), zv[7] ); + + x0 += 8*n_elem_per_reg; + } + + case 2: + + for ( ; (i + 11) < n; i += 12 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + zv[1] = _mm256_mul_pd( alphav, xv[1] ); + zv[2] = _mm256_mul_pd( alphav, xv[2] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (x0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (x0 + 2*n_elem_per_reg), zv[2] ); + + x0 += 3*n_elem_per_reg; + } + + for ( ; (i + 3) < n; i += 4 ) + { + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + zv[0] = _mm256_mul_pd( alphav, xv[0] ); + + _mm256_storeu_pd( (x0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + } + + for ( ; (i + 0) < n; i += 1 ) + { + *x0 *= *alpha; + + x0 += 1; + } } } else { const double alphac = *alpha; - for ( i = 0; i < n; ++i ) + for ( ; i < n; ++i ) { *x0 *= alphac; From 62718f92fe792734988f8dd9a2af2a94957cf1bd Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Tue, 21 Dec 2021 16:49:11 +0530 Subject: [PATCH 13/63] Optimized AXPBYV Kernel using AVX2 Intrinsics Details: - Intrinsic implementation of axpbyv for AVX2 - Bench written for axpbyv - Added definitions in zen contexts AMD-Internal: [CPUPL-1963] Change-Id: I9bc21a6170f5c944eb6e9e9f0e994b9992f8b539 --- bench/Makefile | 12 +- bench/bench_axpbyv.c | 265 +++++++ bench/inputaxpbyv.txt | 40 + config/zen/bli_cntx_init_zen.c | 10 +- config/zen2/bli_cntx_init_zen2.c | 10 +- config/zen3/bli_cntx_init_zen3.c | 10 +- kernels/zen/1/CMakeLists.txt | 2 + kernels/zen/1/bli_axpbyv_zen_int.c | 1099 ++++++++++++++++++++++++++ kernels/zen/1/bli_axpbyv_zen_int10.c | 709 +++++++++++++++++ kernels/zen/bli_kernels_zen.h | 10 + 10 files changed, 2155 insertions(+), 12 deletions(-) create mode 100644 bench/bench_axpbyv.c create mode 100644 bench/inputaxpbyv.txt create mode 100644 kernels/zen/1/bli_axpbyv_zen_int.c create mode 100644 kernels/zen/1/bli_axpbyv_zen_int10.c diff --git a/bench/Makefile b/bench/Makefile index 3ee497212d..d47485b2fc 100755 --- a/bench/Makefile +++ b/bench/Makefile @@ -191,7 +191,8 @@ blis: \ bench_trsv_blis.x \ bench_amaxv_blis.x \ bench_copyv_blis.x \ - bench_swapv_blis.x + bench_swapv_blis.x \ + bench_axpbyv_blis.x openblas: \ bench_gemm_openblas.x \ @@ -205,7 +206,8 @@ openblas: \ bench_trsv_openblas.x \ bench_amaxv_openblas.x \ bench_copyv_openblas.x \ - bench_swapv_openblas.x + bench_swapv_openblas.x \ + bench_axpbyv_openblas.x atlas: \ bench_gemm_atlas.x \ @@ -219,7 +221,8 @@ atlas: \ bench_trsv_atlas.x \ bench_amaxv_atlas.x \ bench_copyv_atlas.x \ - bench_swapv_atlas.x + bench_swapv_atlas.x \ + bench_axpbyv_atlax.x mkl: \ bench_gemm_mkl.x \ @@ -233,7 +236,8 @@ mkl: \ bench_trsv_mkl.x \ bench_amaxv_mkl.x \ bench_copyv_mkl.x \ - bench_swapv_mkl.x + bench_swapv_mkl.x \ + bench_axpbyv_mkl.x # --Object file rules -- diff --git a/bench/bench_axpbyv.c b/bench/bench_axpbyv.c new file mode 100644 index 0000000000..36a203f696 --- /dev/null +++ b/bench/bench_axpbyv.c @@ -0,0 +1,265 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef WIN32 +#include +#else +#include +#endif +#include "blis.h" + +#ifndef DT +#define DT BLIS_DOUBLE +#endif +#define AOCL_MATRIX_INITIALISATION + +int main( int argc, char** argv ) +{ + obj_t x, y, alpha, beta; // BLIS objects + dim_t p_inc = 0; // To keep track of number of inputs + num_t dt; // BLIS datatype + char dt_ch; // {S, D, Z, C} from input + int r, n_repeats; // repetition counter; number of repeats + + double dtime; + double dtime_save; + double gflops; + + FILE* fin = NULL; // Input FILE* + FILE* fout = NULL; // Output FILE* + + n_repeats = N_REPEAT; // Fetched from Makefile + + dt = DT; // Set datatype as BLIS_DOUBLE + + if ( argc < 3 ) + { + printf( "Usage: ./bench_axpbyv_XX.x input.txt output.txt\n" ); + exit( 1 ); + } + + fin = fopen( argv[1], "r" ); // Open input file in read mode + if ( fin == NULL ) + { + printf( "Error opening input file %s\n", argv[1] ); + exit( 1 ); + } + + fout = fopen( argv[2], "w" ); // Open output file in write mode + if ( fout == NULL ) + { + printf( "Error opening output file %s\n", argv[2] ); + exit( 1 ); + } + +#ifdef DEBUG + fprintf( fout, "gflops\n" ); +#else + fprintf(fout, "Dt\t n\t alpha_r\t alpha_i\t beta_r\t beta_i\t gflops\n" ); +#endif + + dim_t n; // dimension + inc_t incx; // stride x + inc_t incy; // stride y + char tmp[256]; // to store function name, line not present in logs + double alpha_r, alpha_i, beta_r, beta_i; + + // {function name} {S, D, C, Z} {n} + // {alpha_r} {alpha_i} {incx} {beta_r} {beta_i} {incy} + while ( fscanf( fin, "%s %c %ld %lf %lf %ld %lf %lf %ld\n", + tmp, &dt_ch, &n, + &alpha_r, &alpha_i, &incx, &beta_r, &beta_i, &incy ) == 9 ) + { + if ( dt_ch == 'D' || dt_ch == 'd' ) dt = BLIS_DOUBLE; + else if ( dt_ch == 'Z' || dt_ch == 'z' ) dt = BLIS_DCOMPLEX; + else if ( dt_ch == 'S' || dt_ch == 's' ) dt = BLIS_FLOAT; + else if ( dt_ch == 'C' || dt_ch == 'c' ) dt = BLIS_SCOMPLEX; + else + { + printf( "Invalid data type %c\n", dt_ch ); + continue; + } + + // Creating BLIS objects + bli_obj_create( dt, n, 1, incx, 1, &x ); // For input vector x + bli_obj_create( dt, n, 1, incy, 1, &y ); // For input vector y + bli_obj_create( dt, 1, 1, 0, 0, &alpha); // For input vector alpha + bli_obj_create( dt, 1, 1, 0, 0, &beta); // For input vector beta + + #ifdef AOCL_MATRIX_INITIALISATION + bli_randm( &x ); + bli_randm( &y ); + #endif + + bli_setsc( alpha_r, alpha_i, &alpha ); + bli_setsc( beta_r, beta_i, &beta ); + + dtime_save = DBL_MAX; + + for ( r = 0; r < n_repeats; ++r ) + { + dtime = bli_clock(); + +#ifdef BLIS + bli_axpbyv( &alpha, &x, &beta, &y ); +#else + f77_int nn = bli_obj_length( &x ); + f77_int blas_incx = bli_obj_vector_inc( &x ); + f77_int blas_incy = bli_obj_vector_inc( &y ); + + if ( bli_is_float( dt ) ) + { + float* alphap = bli_obj_buffer( &alpha ); + float* xp = bli_obj_buffer( &x ); + float* betap = bli_obj_buffer( &beta ); + float* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_saxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + saxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_double( dt ) ) + { + double* alphap = bli_obj_buffer( &alpha ); + double* xp = bli_obj_buffer( &x ); + double* betap = bli_obj_buffer( &beta ); + double* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_daxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + daxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_scomplex( dt ) ) + { + scomplex* alphap = bli_obj_buffer( &alpha ); + scomplex* xp = bli_obj_buffer( &x ); + scomplex* betap = bli_obj_buffer( &beta ); + scomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_caxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + caxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } + else if ( bli_is_dcomplex( dt ) ) + { + dcomplex* alphap = bli_obj_buffer( &alpha ); + dcomplex* xp = bli_obj_buffer( &x ); + dcomplex* betap = bli_obj_buffer( &beta ); + dcomplex* yp = bli_obj_buffer( &y ); + +#ifdef CBLAS + cblas_zaxpby( nn, + *alphap, + xp, + blas_incx, + *betap, + yp, + blas_incy ); +#else + zaxpby_( &nn, + alphap, + xp, + &blas_incx, + betap, + yp, + &blas_incy ); +#endif + } +#endif + + dtime_save = bli_clock_min_diff( dtime_save, dtime ); + } + gflops = ( 3.0 * n ) / ( dtime_save * 1.0e9 ); + if ( bli_is_complex( dt ) ) gflops *= 4.0; + + printf( "data_axpbyv_%s", BLAS ); + + p_inc++; + printf( " %4lu [ %4lu %7.2f ];\n", + (unsigned long)(p_inc), + (unsigned long)n, + gflops ); + + fprintf( fout, "%c\t %ld\t %lf\t %lf\t %lf\t %lf\t %6.3f\n", + dt_ch, n, alpha_r, alpha_i, beta_r, beta_i, gflops ); + fflush( fout ); + + bli_obj_free( &x ); + bli_obj_free( &y ); + } + + return 0; +} \ No newline at end of file diff --git a/bench/inputaxpbyv.txt b/bench/inputaxpbyv.txt new file mode 100644 index 0000000000..3cfc7ae732 --- /dev/null +++ b/bench/inputaxpbyv.txt @@ -0,0 +1,40 @@ +saxpbyv_ S 32 0.900000 0.000000 1 0.900000 0.000000 1 +saxpbyv_ S 64 1.000000 0.000000 1 1.000000 0.000000 1 +saxpbyv_ S 100 -1 0.000000 1 -1 0.000000 1 +saxpbyv_ S 200 -1.100000 0.000000 1 -1.100000 0.000000 1 +saxpbyv_ S 300 1.100000 0.000000 1 1.100000 0.000000 1 +saxpbyv_ S 400 0.900000 0.000000 1 0.900000 0.000000 1 +saxpbyv_ S 500 1.000000 0.000000 1 1.000000 0.000000 1 +saxpbyv_ S 1000 -1 0.000000 1 -1 0.000000 1 +saxpbyv_ S 5000 -1.100000 0.000000 1 -1.100000 0.000000 1 +saxpbyv_ S 10000 1.100000 0.000000 1 1.100000 0.000000 1 +daxpbyv_ D 32 0.900000 0.000000 1 0.900000 0.000000 1 +daxpbyv_ D 64 1.000000 0.000000 1 1.000000 0.000000 1 +daxpbyv_ D 100 -1 0.000000 1 -1 0.000000 1 +daxpbyv_ D 200 -1.100000 0.000000 1 -1.100000 0.000000 1 +daxpbyv_ D 300 1.100000 0.000000 1 1.100000 0.000000 1 +daxpbyv_ D 400 0.900000 0.000000 1 0.900000 0.000000 1 +daxpbyv_ D 500 1.000000 0.000000 1 1.000000 0.000000 1 +daxpbyv_ D 1000 -1 0.000000 1 -1 0.000000 1 +daxpbyv_ D 5000 -1.100000 0.000000 1 -1.100000 0.000000 1 +daxpbyv_ D 10000 1.100000 0.000000 1 1.100000 0.000000 1 +caxpbyv_ C 32 0.900000 -1.100000 1 0.900000 -1.100000 1 +caxpbyv_ C 64 1.000000 1.100000 1 1.000000 1.100000 1 +caxpbyv_ C 100 -1 1.000000 1 -1 1 1 +caxpbyv_ C 200 -1.100000 0.900000 1 -1.100000 0.900000 1 +caxpbyv_ C 300 1.100000 1.000000 1 1.100000 1 1 +caxpbyv_ C 400 0.900000 -1.100000 1 0.900000 -1.100000 1 +caxpbyv_ C 500 1.000000 1.000000 1 1.000000 1 1 +caxpbyv_ C 1000 -1 0.900000 1 -1 0.900000 1 +caxpbyv_ C 5000 -1.100000 -1 1 -1.100000 -1 1 +caxpbyv_ C 10000 1.100000 -1 1 1.100000 -1 1 +zaxpbyv_ Z 32 0.900000 -1.100000 1 0.900000 -1.100000 1 +zaxpbyv_ Z 64 1.000000 1.100000 1 1.000000 1.100000 1 +zaxpbyv_ Z 100 -1 1.000000 1 -1 1 1 +zaxpbyv_ Z 200 -1.100000 0.900000 1 -1.100000 0.900000 1 +zaxpbyv_ Z 300 1.100000 1.000000 1 1.100000 1 1 +zaxpbyv_ Z 400 0.900000 -1.100000 1 0.900000 -1.100000 1 +zaxpbyv_ Z 500 1.000000 1.000000 1 1.000000 1 1 +zaxpbyv_ Z 1000 -1 0.900000 1 -1 0.900000 1 +zaxpbyv_ Z 5000 -1.100000 -1 1 -1.100000 -1 1 +zaxpbyv_ Z 10000 1.100000 -1 1 1.100000 -1 1 diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 7595849866..020e7052b9 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,12 +95,18 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 24, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, + // axpyv #if 0 BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 4f56316a7a..315362067e 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -3,7 +3,7 @@ An object-based framework for developing high-performance BLAS-like libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,13 +107,17 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 24, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - // axpyv + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index fc7dbcb808..ef47987454 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -107,13 +107,17 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 20, + 24, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, BLIS_AMAXV_KER, BLIS_DOUBLE, bli_damaxv_zen_int, #endif - // axpyv + // axpbyv + BLIS_AXPBYV_KER, BLIS_FLOAT, bli_saxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_DOUBLE, bli_daxpbyv_zen_int10, + BLIS_AXPBYV_KER, BLIS_SCOMPLEX, bli_caxpbyv_zen_int, + BLIS_AXPBYV_KER, BLIS_DCOMPLEX, bli_zaxpbyv_zen_int, // axpyv BLIS_AXPYV_KER, BLIS_FLOAT, bli_saxpyv_zen_int10, diff --git a/kernels/zen/1/CMakeLists.txt b/kernels/zen/1/CMakeLists.txt index 669a3ba89a..434be490d5 100644 --- a/kernels/zen/1/CMakeLists.txt +++ b/kernels/zen/1/CMakeLists.txt @@ -3,6 +3,8 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_amaxv_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpbyv_zen_int10.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyv_zen_int10.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_copyv_zen_int.c diff --git a/kernels/zen/1/bli_axpbyv_zen_int.c b/kernels/zen/1/bli_axpbyv_zen_int.c new file mode 100644 index 0000000000..05ef96175a --- /dev/null +++ b/kernels/zen/1/bli_axpbyv_zen_int.c @@ -0,0 +1,1099 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union DS to access AVX registers */ +/* One 256-bit AVX register holds 8 SP elements */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* One 256-bit AVX register holds 4 DP elements */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/** + * saxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are single precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_saxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + const dim_t n_iter_unroll = 4; // num of registers per iteration + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t y0v, y1v, y2v, y3v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + return; + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_ss( alpha ); + betav.v = _mm256_broadcast_ss( beta ); + + // unrolling and vectorizing + for ( i = 0; ( i + 31 ) < n; i += 32 ) + { + // loading input y + y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_ps( betav.v, y0v.v ); + y1v.v = _mm256_mul_ps( betav.v, y1v.v ); + y2v.v = _mm256_mul_ps( betav.v, y2v.v ); + y3v.v = _mm256_mul_ps( betav.v, y3v.v ); + + // y := y' + alpha * x + y0v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + y3v.v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), y3v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_daxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + const dim_t n_iter_unroll = 4; // number of registers per iteration + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t y0v, y1v, y2v, y3v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_sd( alpha ); + betav.v = _mm256_broadcast_sd( beta ); + + // unrolling and vectorizing + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } +} + +/** + * caxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are simple complex vectors of length n. + * alpha & beta are scalers. + */ +void bli_caxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + float alphaR, alphaI, betaR, betaI; + + __m256 alphaRv; + __m256 alphaIv; + __m256 betaRv; + __m256 betaIv; + __m256 xv[4]; + __m256 yv[4]; + __m256 iv[4]; // intermediate registers + + conj_t conjx_use = conjx; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = ( float* ) x; + y0 = ( float* ) y; + + alphaR = alpha->real; + alphaI = alpha->imag; + betaR = beta->real; + betaI = beta->imag; + + if ( incx == 1 && incy == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // y = beta*y + alpha*x + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) + // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI + // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + + // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_NO_CONJUGATE + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR aR aR aR aR aR aR aR + // aiv = -aI aI -aI aI -aI aI -aI aI + // brv = bR bR bR bR bR bR bR bR + // biv = -bI bI -bI bI -bI bI -bI bI + + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // y = beta*y + alpha*conj(x) + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) + // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI + // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + + // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_CONJUGATE + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + // arv = aR -aR aR -aR aR -aR aR -aR + // aiv = aI aI aI aI aI aI aI aI + // brv = bR bR bR bR bR bR bR bR + // biv = -bI bI -bI bI -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + // broadcast alpha & beta to all elements of respective vector registers + if ( !bli_is_conj( conjx ) ) // If BLIS_NO_CONJUGATE + { + // alphaRv = aR aR aR aR aR aR aR aR + // alphaIv = -aI aI -aI aI -aI aI -aI aI + // betaRv = bR bR bR bR bR bR bR bR + // betaIv = -bI bI -bI bI -bI bI -bI bI + alphaRv = _mm256_broadcast_ss( &alphaR ); + alphaIv = _mm256_set_ps + ( + alphaI, -alphaI, alphaI, -alphaI, + alphaI, -alphaI, alphaI, -alphaI + ); + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_set_ps + ( + betaI, -betaI, betaI, -betaI, + betaI, -betaI, betaI, -betaI + ); + } + else + { + // alphaRv = aR -aR aR -aR aR -aR aR -aR + // alphaIv = aI aI aI aI aI aI aI aI + // betaRv = bR bR bR bR bR bR bR bR + // betaIv = -bI bI -bI bI -bI bI -bI bI + alphaRv = _mm256_set_ps + ( + -alphaR, alphaR, -alphaR, alphaR, + -alphaR, alphaR, -alphaR, alphaR + ); + alphaIv = _mm256_broadcast_ss( &alphaI ); + betaRv = _mm256_broadcast_ss( &betaR ); + betaIv = _mm256_set_ps + ( + betaI, -betaI, betaI, -betaI, + betaI, -betaI, betaI, -betaI + ); + } + + // Processing 16 elements per loop, 8 FMAs + for ( i = 0; ( i + 15 ) < n; i += 16 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + iv[2] = _mm256_mul_ps( betaRv, yv[2] ); + iv[3] = _mm256_mul_ps( betaRv, yv[3] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + yv[2] = _mm256_permute_ps( yv[2], 0xB1); + yv[3] = _mm256_permute_ps( yv[3], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); + yv[3] = _mm256_fmadd_ps( betaIv, yv[3], iv[3] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + iv[3] = _mm256_mul_ps( alphaRv, xv[3] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + xv[2] = _mm256_permute_ps( xv[2], 0xB1); + xv[3] = _mm256_permute_ps( xv[3], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_ps( alphaIv, xv[3], yv[3] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_ps( (y0 + 3*n_elem_per_reg), yv[3] ); + + y0 += 4*n_elem_per_reg; + x0 += 4*n_elem_per_reg; + } + + // Processing 12 elements per loop, 6 FMAs + for ( ; ( i + 11 ) < n; i += 12 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + iv[2] = _mm256_mul_ps( betaRv, yv[2] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + yv[2] = _mm256_permute_ps( yv[2], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_ps( betaIv, yv[2], iv[2] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + iv[2] = _mm256_mul_ps( alphaRv, xv[2] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + xv[2] = _mm256_permute_ps( xv[2], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_ps( alphaIv, xv[2], yv[2] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_ps( (y0 + 2*n_elem_per_reg), yv[2] ); + + y0 += 3*n_elem_per_reg; + x0 += 3*n_elem_per_reg; + } + + // Processing 16 elements per loop, 8 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // xv = xR1 xI1 xR2 xI2 xR3 xI3 xR4 xI4 + xv[0] = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 yR3 yI3 yR4 yI4 + yv[0] = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_ps( betaRv, yv[0] ); + iv[1] = _mm256_mul_ps( betaRv, yv[1] ); + + // yv' = yI1 yR1 yI2 yR2 yI3 yR3 yI4 yR4 + yv[0] = _mm256_permute_ps( yv[0], 0xB1); + yv[1] = _mm256_permute_ps( yv[1], 0xB1); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_ps( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_ps( betaIv, yv[1], iv[1] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_ps( alphaRv, xv[0] ); + iv[1] = _mm256_mul_ps( alphaRv, xv[1] ); + + // xv' = xI1 xR1 xI2 xR2 xI3 xR3 xI4 xR4 + xv[0] = _mm256_permute_ps( xv[0], 0xB1); + xv[1] = _mm256_permute_ps( xv[1], 0xB1); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_ps( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_ps( alphaIv, xv[1], yv[1] ); + + _mm256_storeu_ps( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_ps( (y0 + 1*n_elem_per_reg), yv[1] ); + + y0 += 2*n_elem_per_reg; + x0 += 2*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + if ( !bli_is_conj( conjx_use ) ) + { + for ( ; i < n ; ++i ) + { + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; i < n ; ++i ) + { + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + } + else + { + // for non-unit increments, use scaler code + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * zaxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double complex vectors of length n. + * alpha & beta are scalers. + */ +void bli_zaxpbyv_zen_int + ( + conj_t conjx, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + double alphaR, alphaI, betaR, betaI; + + __m256d alphaRv; + __m256d alphaIv; + __m256d betaRv; + __m256d betaIv; + __m256d xv[4]; + __m256d yv[4]; + __m256d iv[4]; // intermediate registers + + conj_t conjx_use = conjx; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( c, eq0 )( *alpha ) && PASTEMAC( c, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = ( double* ) x; + y0 = ( double* ) y; + + alphaR = alpha->real; + alphaI = alpha->imag; + betaR = beta->real; + betaI = beta->imag; + + if ( incx == 1 && incy == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // y = beta*y + alpha*x + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR + ixI ) + // y = bR.yR + ibR.yI + ibI.yR - ibIyI + aR.xR + iaR.xI + iaI.xR - aI.xI + // y = ( bR.yR - bI.yI + aR.xR - aI.xI ) + + // i ( bR.yI + bI.yR + aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_NO_CONJUGATE + // yv = yR1 yI1 yR2 yI2 + // yv' = yI1 yR1 yI2 yR2 + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR aR aR aR + // aiv = -aI aI -aI aI + // brv = bR bR bR bR + // biv = -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // y = beta*y + alpha*conj(x) + // y = ( bR + ibI ) * ( yR + iyI ) + ( aR + iaI ) * ( xR - ixI ) + // y = bR.yR + ibR.yI + ibI.yR - bI.yI + aR.xR - iaR.xI + iaI.xR + aI.xI + // y = ( bR.yR - bI.yI + aR.xR + aI.xI ) + + // i ( bR.yI + bI.yR - aR.xI + aI.xR ) + + // SIMD Algorithm BLIS_CONJUGATE + // yv = yR1 yI1 yR2 yI2 + // yv' = yI1 yR1 yI2 yR2 + // xv = xR1 xI1 xR2 xI2 + // xv' = xI1 xR1 xI2 xR2 + // arv = aR -aR aR -aR + // aiv = aI aI aI aI + // brv = bR bR bR bR + // biv = -bI bI -bI bI + // + // step 1: iv = brv * iv + // step 2: shuffle yv -> yv' + // step 3: FMA yv = biv * yv' + iv + // step 4: iv = arv * xv + // step 5: shuffle xv -> xv' + // step 6: FMA yv = aiv * xv' + iv + + // broadcast alpha & beta to all elements of respective vector registers + if ( !bli_is_conj( conjx ) ) + { + // alphaRv = aR aR aR aR + // alphaIv = -aI aI -aI aI + // betaRv = bR bR bR bR + // betaIv = -bI bI -bI bI + alphaRv = _mm256_broadcast_sd( &alphaR ); + alphaIv = _mm256_set_pd( alphaI, -alphaI, alphaI, -alphaI ); + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); + } + else + { + // alphaRv = aR -aR aR -aR + // alphaIv = aI aI aI aI + // betaRv = bR bR bR bR + // betaIv = -bI bI -bI bI + alphaRv = _mm256_set_pd( -alphaR, alphaR, -alphaR, alphaR ); + alphaIv = _mm256_broadcast_sd( &alphaI ); + betaRv = _mm256_broadcast_sd( &betaR ); + betaIv = _mm256_set_pd( betaI, -betaI, betaI, -betaI ); + } + + // Processing 8 elements per loop, 8 FMAs + for ( i = 0; ( i + 7 ) < n; i += 8 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + iv[3] = _mm256_mul_pd( betaRv, yv[3] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + yv[2] = _mm256_permute_pd( yv[2], 5); + yv[3] = _mm256_permute_pd( yv[3], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + yv[3] = _mm256_fmadd_pd( betaIv, yv[3], iv[3] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + iv[3] = _mm256_mul_pd( alphaRv, xv[3] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + xv[2] = _mm256_permute_pd( xv[2], 5); + xv[3] = _mm256_permute_pd( xv[3], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] ); + yv[3] = _mm256_fmadd_pd( alphaIv, xv[3], yv[3] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3] ); + + y0 += 4*n_elem_per_reg; + x0 += 4*n_elem_per_reg; + } + + // Processing 6 elements per loop, 6 FMAs + for ( ; ( i + 5 ) < n; i += 6 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + iv[2] = _mm256_mul_pd( betaRv, yv[2] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + yv[2] = _mm256_permute_pd( yv[2], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + yv[2] = _mm256_fmadd_pd( betaIv, yv[2], iv[2] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + iv[2] = _mm256_mul_pd( alphaRv, xv[2] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + xv[2] = _mm256_permute_pd( xv[2], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + yv[2] = _mm256_fmadd_pd( alphaIv, xv[2], yv[2] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2] ); + + y0 += 3*n_elem_per_reg; + x0 += 3*n_elem_per_reg; + } + + // Processing 4 elements per loop, 4 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + iv[1] = _mm256_mul_pd( betaRv, yv[1] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + yv[1] = _mm256_permute_pd( yv[1], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + yv[1] = _mm256_fmadd_pd( betaIv, yv[1], iv[1] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + iv[1] = _mm256_mul_pd( alphaRv, xv[1] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + xv[1] = _mm256_permute_pd( xv[1], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + yv[1] = _mm256_fmadd_pd( alphaIv, xv[1], yv[1] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1] ); + + y0 += 2*n_elem_per_reg; + x0 += 2*n_elem_per_reg; + } + + // Processing 2 elements per loop, 3 FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // xv = xR1 xI1 xR2 xI2 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // yv = yR1 yI1 yR2 yI2 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // iv = betaRv * yv + // = yR1.bR, yI1.bR, yR2.bR, yI2.bR, ... + iv[0] = _mm256_mul_pd( betaRv, yv[0] ); + + // yv' = yI1 yR1 yI2 yR2 + yv[0] = _mm256_permute_pd( yv[0], 5); + + // yv = betaIv * yv' + iv + // = yR1.bR - yI1.bI, yI1.bR + yR1.bI, ... + yv[0] = _mm256_fmadd_pd( betaIv, yv[0], iv[0] ); + + // iv = alphaRv * xv + // = xR1.aR, xI1.aR, xR2.aR, xI2.aR, ... + iv[0] = _mm256_mul_pd( alphaRv, xv[0] ); + + // xv' = xI1 xR1 xI2 xR2 + xv[0] = _mm256_permute_pd( xv[0], 5); + + // yv = alphaIv * xv + yv + // = yR1.bR - yR1.bI - xR1.aI, yI1.bR + yI1.bI + xI1.aI, ... + yv[0] = _mm256_fmadd_pd( alphaIv, xv[0], yv[0] ); + + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0] ); + + y0 += 1*n_elem_per_reg; + x0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + if ( !bli_is_conj( conjx_use ) ) + { + for ( ; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + else + { + for ( ; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += 2; + y0 += 2; + } + } + } + else + { + // for non-unit increments, use scaler code + if ( !bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) - ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) + + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + else + { + for ( i = 0; i < n ; ++i ) + { + // yReal = ( bR.yR - bI.yI + aR.xR - aI.xI ) + *y0 = ( betaR * (*y0) ) - ( betaI * (*(y0 + 1)) ) + + ( alphaR * (*x0) ) + ( alphaI * (*(x0 + 1)) ); + // yImag = ( bR.yI + bI.yR + aR.xI + aI.xR ) + *(y0 + 1) = ( betaR * (*(y0 + 1)) ) + ( betaI * (*y0) ) - + ( alphaR * (*(x0 + 1)) ) + ( alphaI * (*x0) ); + + x0 += incx * 2; + y0 += incy * 2; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/1/bli_axpbyv_zen_int10.c b/kernels/zen/1/bli_axpbyv_zen_int10.c new file mode 100644 index 0000000000..bbfdaf0d6a --- /dev/null +++ b/kernels/zen/1/bli_axpbyv_zen_int10.c @@ -0,0 +1,709 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +/* Union DS to access AVX registers */ +/* One 256-bit AVX register holds 8 SP elements */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + +/* One 256-bit AVX register holds 4 DP elements */ +typedef union +{ + __m256d v; + double d[4] __attribute__((aligned(64))); +} v4df_t; + +/** + * saxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are single precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_saxpbyv_zen_int10 + ( + conj_t conjx, + dim_t n, + float* restrict alpha, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 8; // number of elements per register + + dim_t i; // iterator + + float* restrict x0; + float* restrict y0; + + v8sf_t alphav; + v8sf_t betav; + v8sf_t yv[10]; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_ss( alpha ); + betav.v = _mm256_broadcast_ss( beta ); + + // Processing 80 elements per loop, 10 FMAs + for ( i = 0; ( i + 79 ) < n; i += 80 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + yv[5].v = _mm256_loadu_ps( y0 + 5*n_elem_per_reg ); + yv[6].v = _mm256_loadu_ps( y0 + 6*n_elem_per_reg ); + yv[7].v = _mm256_loadu_ps( y0 + 7*n_elem_per_reg ); + yv[8].v = _mm256_loadu_ps( y0 + 8*n_elem_per_reg ); + yv[9].v = _mm256_loadu_ps( y0 + 9*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + yv[5].v = _mm256_mul_ps( betav.v, yv[5].v ); + yv[6].v = _mm256_mul_ps( betav.v, yv[6].v ); + yv[7].v = _mm256_mul_ps( betav.v, yv[7].v ); + yv[8].v = _mm256_mul_ps( betav.v, yv[8].v ); + yv[9].v = _mm256_mul_ps( betav.v, yv[9].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + yv[5].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 5*n_elem_per_reg ), + yv[5].v + ); + yv[6].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 6*n_elem_per_reg ), + yv[6].v + ); + yv[7].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 7*n_elem_per_reg ), + yv[7].v + ); + yv[8].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 8*n_elem_per_reg ), + yv[8].v + ); + yv[9].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 9*n_elem_per_reg ), + yv[9].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + _mm256_storeu_ps( ( y0 + 5*n_elem_per_reg ), yv[5].v ); + _mm256_storeu_ps( ( y0 + 6*n_elem_per_reg ), yv[6].v ); + _mm256_storeu_ps( ( y0 + 7*n_elem_per_reg ), yv[7].v ); + _mm256_storeu_ps( ( y0 + 8*n_elem_per_reg ), yv[8].v ); + _mm256_storeu_ps( ( y0 + 9*n_elem_per_reg ), yv[9].v ); + + x0 += 10 * n_elem_per_reg; + y0 += 10 * n_elem_per_reg; + } + + // Processing 40 elements per loop, 5 FMAs + for ( ; ( i + 39 ) < n; i += 40 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + yv[4].v = _mm256_loadu_ps( y0 + 4*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + yv[4].v = _mm256_mul_ps( betav.v, yv[4].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + yv[4].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 4*n_elem_per_reg ), + yv[4].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + _mm256_storeu_ps( ( y0 + 4*n_elem_per_reg ), yv[4].v ); + + x0 += 5 * n_elem_per_reg; + y0 += 5 * n_elem_per_reg; + } + + // Processing 32 elements per loop, 4 FMAs + for ( ; ( i + 31 ) < n; i += 32 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + yv[2].v = _mm256_mul_ps( betav.v, yv[2].v ); + yv[3].v = _mm256_mul_ps( betav.v, yv[3].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + yv[2].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 2*n_elem_per_reg ), + yv[2].v + ); + yv[3].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 3*n_elem_per_reg ), + yv[3].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + _mm256_storeu_ps( ( y0 + 2*n_elem_per_reg ), yv[2].v ); + _mm256_storeu_ps( ( y0 + 3*n_elem_per_reg ), yv[3].v ); + + x0 += 4 * n_elem_per_reg; + y0 += 4 * n_elem_per_reg; + } + + // Processing 16 elements per loop, 2 FMAs + for ( ; ( i + 15 ) < n; i += 16 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + yv[1].v = _mm256_mul_ps( betav.v, yv[1].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + yv[1].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 1*n_elem_per_reg ), + yv[1].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + _mm256_storeu_ps( ( y0 + 1*n_elem_per_reg ), yv[1].v ); + + x0 += 2 * n_elem_per_reg; + y0 += 2 * n_elem_per_reg; + } + + // Processing 8 elements per loop, 1 FMA + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input values + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + yv[0].v = _mm256_mul_ps( betav.v, yv[0].v ); + + // y := y' + alpha * x + yv[0].v = _mm256_fmadd_ps + ( + alphav.v, + _mm256_loadu_ps( x0 + 0*n_elem_per_reg ), + yv[0].v + ); + + // storing the output + _mm256_storeu_ps( ( y0 + 0*n_elem_per_reg ), yv[0].v ); + + x0 += 1 * n_elem_per_reg; + y0 += 1 * n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; i++ ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} + +/** + * daxpbyv kernel performs the axpbyv operation. + * y := beta * y + alpha * conjx(x) + * where, + * x & y are double precision vectors of length n. + * alpha & beta are scalers. + */ +void bli_daxpbyv_zen_int10 + ( + conj_t conjx, + dim_t n, + double* restrict alpha, + double* restrict x, inc_t incx, + double* restrict beta, + double* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + const dim_t n_elem_per_reg = 4; // number of elements per register + const dim_t n_iter_unroll = 10; // number of registers per iteration + + dim_t i; // iterator + + double* restrict x0; + double* restrict y0; + + v4df_t alphav; + v4df_t betav; + v4df_t y0v, y1v, y2v, y3v, y4v, y5v, y6v, y7v, y8v, y9v; + + /* if the vector dimension is zero, or if alpha & beta are zero, + return early. */ + if ( bli_zero_dim1( n ) || + ( PASTEMAC( s, eq0 )( *alpha ) && PASTEMAC( s, eq0 )( *beta ) ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + // initialize local pointers + x0 = x; + y0 = y; + + if ( incx == 1 && incy == 1 ) + { + // broadcast alpha & beta to all elements of respective vector registers + alphav.v = _mm256_broadcast_sd( alpha ); + betav.v = _mm256_broadcast_sd( beta ); + + // Using 10 FMAs per loop + for ( i = 0; ( i + 39 ) < n; i += 40 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + y5v.v = _mm256_loadu_pd( y0 + 5*n_elem_per_reg ); + y6v.v = _mm256_loadu_pd( y0 + 6*n_elem_per_reg ); + y7v.v = _mm256_loadu_pd( y0 + 7*n_elem_per_reg ); + y8v.v = _mm256_loadu_pd( y0 + 8*n_elem_per_reg ); + y9v.v = _mm256_loadu_pd( y0 + 9*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + y4v.v = _mm256_mul_pd( betav.v, y4v.v ); + y5v.v = _mm256_mul_pd( betav.v, y5v.v ); + y6v.v = _mm256_mul_pd( betav.v, y6v.v ); + y7v.v = _mm256_mul_pd( betav.v, y7v.v ); + y8v.v = _mm256_mul_pd( betav.v, y8v.v ); + y9v.v = _mm256_mul_pd( betav.v, y9v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + y4v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + y4v.v + ); + y5v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 5*n_elem_per_reg ), + y5v.v + ); + y6v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 6*n_elem_per_reg ), + y6v.v + ); + y7v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 7*n_elem_per_reg ), + y7v.v + ); + y8v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 8*n_elem_per_reg ), + y8v.v + ); + y9v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 9*n_elem_per_reg ), + y9v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); + _mm256_storeu_pd( ( y0 + 5*n_elem_per_reg ), y5v.v ); + _mm256_storeu_pd( ( y0 + 6*n_elem_per_reg ), y6v.v ); + _mm256_storeu_pd( ( y0 + 7*n_elem_per_reg ), y7v.v ); + _mm256_storeu_pd( ( y0 + 8*n_elem_per_reg ), y8v.v ); + _mm256_storeu_pd( ( y0 + 9*n_elem_per_reg ), y9v.v ); + + x0 += n_elem_per_reg * n_iter_unroll; + y0 += n_elem_per_reg * n_iter_unroll; + } + + // Using 5 FMAs per loop + for ( ; ( i + 19 ) < n; i += 20 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + y4v.v = _mm256_loadu_pd( y0 + 4*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + y2v.v = _mm256_mul_pd( betav.v, y2v.v ); + y3v.v = _mm256_mul_pd( betav.v, y3v.v ); + y4v.v = _mm256_mul_pd( betav.v, y4v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + y2v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 2*n_elem_per_reg ), + y2v.v + ); + y3v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 3*n_elem_per_reg ), + y3v.v + ); + y4v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 4*n_elem_per_reg ), + y4v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + _mm256_storeu_pd( ( y0 + 2*n_elem_per_reg ), y2v.v ); + _mm256_storeu_pd( ( y0 + 3*n_elem_per_reg ), y3v.v ); + _mm256_storeu_pd( ( y0 + 4*n_elem_per_reg ), y4v.v ); + + x0 += n_elem_per_reg * 5; + y0 += n_elem_per_reg * 5; + } + + // Using 2 FMAs per loop + for ( ; ( i + 7 ) < n; i += 8 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + y1v.v = _mm256_mul_pd( betav.v, y1v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + y1v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 1*n_elem_per_reg ), + y1v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + _mm256_storeu_pd( ( y0 + 1*n_elem_per_reg ), y1v.v ); + + x0 += n_elem_per_reg * 2; + y0 += n_elem_per_reg * 2; + } + + // Using 1 FMAs per loop + for ( ; ( i + 3 ) < n; i += 4 ) + { + // loading input y + y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // y' := y := beta * y + y0v.v = _mm256_mul_pd( betav.v, y0v.v ); + + // y := y' + alpha * x + // := beta * y + alpha * x + y0v.v = _mm256_fmadd_pd + ( + alphav.v, + _mm256_loadu_pd( x0 + 0*n_elem_per_reg ), + y0v.v + ); + + // storing the output + _mm256_storeu_pd( ( y0 + 0*n_elem_per_reg ), y0v.v ); + + x0 += n_elem_per_reg * 1; + y0 += n_elem_per_reg * 1; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // if there are leftover iterations, perform them with scaler code + for ( ; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + else + { + // for non-unit increments, use scaler code + for ( i = 0; i < n; ++i ) + { + *y0 = ( (*alpha) * (*x0) ) + ( (*beta) * (*y0) ); + + x0 += incx; + y0 += incy; + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index e8cbe49d15..42a92809c2 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -54,6 +54,16 @@ PACKM_KER_PROT(double, d, packm_6xk_nn_zen) AMAXV_KER_PROT( float, s, amaxv_zen_int ) AMAXV_KER_PROT( double, d, amaxv_zen_int ) +// axpbyv (intrinsics) +AXPBYV_KER_PROT( float, s, axpbyv_zen_int ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int ) +AXPBYV_KER_PROT( scomplex, c, axpbyv_zen_int ) +AXPBYV_KER_PROT( dcomplex, z, axpbyv_zen_int ) + +// axpbyv (intrinsics, unrolled x10) +AXPBYV_KER_PROT( float, s, axpbyv_zen_int10 ) +AXPBYV_KER_PROT( double, d, axpbyv_zen_int10 ) + // axpyv (intrinsics) AXPYV_KER_PROT( float, s, axpyv_zen_int ) AXPYV_KER_PROT( double, d, axpyv_zen_int ) From b8400f95eeb345e088ddaea1ba0e8efd6410a2cd Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Tue, 7 Dec 2021 00:56:16 -0600 Subject: [PATCH 14/63] Optimized ztrsv implementation - Implemented alternate method of performing multiplication and addition operations on double precision complex datatype by separating out real and imaginary parts of complex number. - Optimal and reuse of vector registers for faster computation. AMD-Internal: [CPUPL-1969] Change-Id: Ib181f193c05740d5f6b9de3930e1995dea4a50f2 --- kernels/zen/1f/bli_axpyf_zen_int_5.c | 891 ++++++++++++++++----------- 1 file changed, 528 insertions(+), 363 deletions(-) diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index f770389196..1125197775 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1747,8 +1747,17 @@ void bli_caxpyf_zen_int_5 } -// ----------------------------------------------------------------------------- - +//------------------------------------------------------------------------------ +/** + * Following kernel performs axpyf operation on dcomplex data. + * Operate over 5 columns of a matrix at a time and march through + * rows in steps of 4 or 2. + * For optimal performance, it separate outs imaginary and real + * components of chis and broadcast them into separate ymm vector + * registers. + * By doing so it avoids necessity of permute operation to get the + * final result of dcomp-lex multiplication. + */ void bli_zaxpyf_zen_int_5 ( conj_t conja, @@ -1762,391 +1771,547 @@ void bli_zaxpyf_zen_int_5 cntx_t* restrict cntx ) { - const dim_t fuse_fac = 5; + const dim_t fuse_fac = 5; - const dim_t n_elem_per_reg = 2; - const dim_t n_iter_unroll = 2; + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 2; - dim_t i = 0; - dim_t setPlusOne = 1; + dim_t i = 0; + dim_t setPlusOne = 1; - v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; - v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; + v4df_t chi0v, chi1v, chi2v, chi3v, chi4v; + v4df_t chi5v, chi6v, chi7v, chi8v, chi9v; - v4df_t a00v, a01v, a02v, a03v, a04v; - v4df_t a05v, a06v, a07v, a08v, a09v; + v4df_t a00v, a01v, a02v, a03v, a04v; - v4df_t a10v, a11v, a12v, a13v, a14v; - v4df_t a15v, a16v, a17v, a18v, a19v; + v4df_t a10v, a11v, a12v, a13v, a14v; - v4df_t y0v, y1v; - v4df_t setMinus, setPlus; + v4df_t y0v, y1v, y2v, y3v; + v4df_t r0v, r1v, conjv; - dcomplex chi0, chi1, chi2, chi3, chi4; - dcomplex* restrict a0; - dcomplex* restrict a1; - dcomplex* restrict a2; - dcomplex* restrict a3; - dcomplex* restrict a4; + dcomplex chi0, chi1, chi2, chi3, chi4; + dcomplex* restrict a0; + dcomplex* restrict a1; + dcomplex* restrict a2; + dcomplex* restrict a3; + dcomplex* restrict a4; - dcomplex* restrict y0; + dcomplex* restrict y0; - if ( bli_is_conj(conja) ){ - setPlusOne = -1; - } + if ( bli_is_conj(conja) ){ + setPlusOne = -1; + } - // If either dimension is zero, or if alpha is zero, return early. - if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; + // If either dimension is zero, or if alpha is zero, return early. + if ( bli_zero_dim2( m, b_n ) || bli_zeq0( *alpha ) ) return; - // If b_n is not equal to the fusing factor, then perform the entire - // operation as a loop over axpyv. - if ( b_n != fuse_fac ) - { + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over axpyv. + if ( b_n != fuse_fac ) + { #ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + bli_zaxpyv_zen_int5 + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } #else - zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); - - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - f - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } + zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); + + for ( i = 0; i < b_n; ++i ) + { + dcomplex* a1 = a + (0 )*inca + (i )*lda; + dcomplex* chi1 = x + (i )*incx; + dcomplex* y1 = y + (0 )*incy; + dcomplex alpha_chi1; + + bli_zcopycjs( conjx, *chi1, alpha_chi1 ); + bli_zscals( *alpha, alpha_chi1 ); + + f + ( + conja, + m, + &alpha_chi1, + a1, inca, + y1, incy, + cntx + ); + } #endif - return; - } - - - // At this point, we know that b_n is exactly equal to the fusing factor. - - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - y0 = y; - - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - - dcomplex *pchi0 = x + 0*incx ; - dcomplex *pchi1 = x + 1*incx ; - dcomplex *pchi2 = x + 2*incx ; - dcomplex *pchi3 = x + 3*incx ; - dcomplex *pchi4 = x + 4*incx ; - - bli_zcopycjs( conjx, *pchi0, chi0 ); - bli_zcopycjs( conjx, *pchi1, chi1 ); - bli_zcopycjs( conjx, *pchi2, chi2 ); - bli_zcopycjs( conjx, *pchi3, chi3 ); - bli_zcopycjs( conjx, *pchi4, chi4 ); - - // Scale each chi scalar by alpha. - bli_zscals( *alpha, chi0 ); - bli_zscals( *alpha, chi1 ); - bli_zscals( *alpha, chi2 ); - bli_zscals( *alpha, chi3 ); - bli_zscals( *alpha, chi4 ); - - // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0.real ); - chi1v.v = _mm256_broadcast_sd( &chi1.real ); - chi2v.v = _mm256_broadcast_sd( &chi2.real ); - chi3v.v = _mm256_broadcast_sd( &chi3.real ); - chi4v.v = _mm256_broadcast_sd( &chi4.real ); - - chi5v.v = _mm256_broadcast_sd( &chi0.imag ); - chi6v.v = _mm256_broadcast_sd( &chi1.imag ); - chi7v.v = _mm256_broadcast_sd( &chi2.imag ); - chi8v.v = _mm256_broadcast_sd( &chi3.imag ); - chi9v.v = _mm256_broadcast_sd( &chi4.imag ); - - // If there are vectorized iterations, perform them with vector - // instructions. - if ( inca == 1 && incy == 1 ) - { - setMinus.v = _mm256_set_pd( -1, 1, -1, 1 ); - - setPlus.v = _mm256_set1_pd( 1 ); - if ( bli_is_conj(conja) ){ - setPlus.v = _mm256_set_pd( -1, 1, -1, 1 ); - } - - /* - y := y + alpha * conja(A) * conjx(x) - - nn - (ar + ai) (xr + xi) - ar * xr - ai * xi - ar * xi + ai * xr - - cc : (ar - ai) (xr - xi) - ar * xr - ai * xi - -(ar * xi + ai * xr) - - nc : (ar + ai) (xr - xi) - ar * xr + ai * xi - -(ar * xi - ai * xr) - - cn : (ar - ai) (xr + xi) - ar * xr + ai * xi - ar * xi - ai * xr - - */ - - for( i = 0; (i + 3) < m; i += 4 ) - { - // Load the input values. - y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); - y1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); - - a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); - - a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); - - a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); - a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); - - a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); - a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); - - a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); - - a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); - a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); - a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); - a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); - a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); - - a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); - a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); - a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); - a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); - a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); - - a05v.v = _mm256_permute_pd( a05v.v, 5 ); - a06v.v = _mm256_permute_pd( a06v.v, 5 ); - a07v.v = _mm256_permute_pd( a07v.v, 5 ); - a08v.v = _mm256_permute_pd( a08v.v, 5 ); - a09v.v = _mm256_permute_pd( a09v.v, 5 ); - - a10v.v = _mm256_mul_pd( a10v.v, setPlus.v ); - a11v.v = _mm256_mul_pd( a11v.v, setPlus.v ); - a12v.v = _mm256_mul_pd( a12v.v, setPlus.v ); - a13v.v = _mm256_mul_pd( a13v.v, setPlus.v ); - a14v.v = _mm256_mul_pd( a14v.v, setPlus.v ); - - a15v.v = _mm256_mul_pd( a10v.v, setMinus.v ); - a16v.v = _mm256_mul_pd( a11v.v, setMinus.v ); - a17v.v = _mm256_mul_pd( a12v.v, setMinus.v ); - a18v.v = _mm256_mul_pd( a13v.v, setMinus.v ); - a19v.v = _mm256_mul_pd( a14v.v, setMinus.v ); - - a15v.v = _mm256_permute_pd( a15v.v, 5 ); - a16v.v = _mm256_permute_pd( a16v.v, 5 ); - a17v.v = _mm256_permute_pd( a17v.v, 5 ); - a18v.v = _mm256_permute_pd( a18v.v, 5 ); - a19v.v = _mm256_permute_pd( a19v.v, 5 ); - - // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - - y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); + return; + } + + + // At this point, we know that b_n is exactly equal to the fusing factor. + + a0 = a + 0*lda; + a1 = a + 1*lda; + a2 = a + 2*lda; + a3 = a + 3*lda; + a4 = a + 4*lda; + y0 = y; + + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + chi4 = *( x + 4*incx ); + + dcomplex *pchi0 = x + 0*incx ; + dcomplex *pchi1 = x + 1*incx ; + dcomplex *pchi2 = x + 2*incx ; + dcomplex *pchi3 = x + 3*incx ; + dcomplex *pchi4 = x + 4*incx ; + + bli_zcopycjs( conjx, *pchi0, chi0 ); + bli_zcopycjs( conjx, *pchi1, chi1 ); + bli_zcopycjs( conjx, *pchi2, chi2 ); + bli_zcopycjs( conjx, *pchi3, chi3 ); + bli_zcopycjs( conjx, *pchi4, chi4 ); + + // Scale each chi scalar by alpha. + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + bli_zscals( *alpha, chi4 ); + + // Broadcast the (alpha*chi?) scalars to all elements of vector registers. + chi0v.v = _mm256_broadcast_sd( &chi0.real ); + chi1v.v = _mm256_broadcast_sd( &chi1.real ); + chi2v.v = _mm256_broadcast_sd( &chi2.real ); + chi3v.v = _mm256_broadcast_sd( &chi3.real ); + chi4v.v = _mm256_broadcast_sd( &chi4.real ); + + chi5v.v = _mm256_broadcast_sd( &chi0.imag ); + chi6v.v = _mm256_broadcast_sd( &chi1.imag ); + chi7v.v = _mm256_broadcast_sd( &chi2.imag ); + chi8v.v = _mm256_broadcast_sd( &chi3.imag ); + chi9v.v = _mm256_broadcast_sd( &chi4.imag ); + + // If there are vectorized iterations, perform them with vector + // instructions. + if ( inca == 1 && incy == 1 ) + { + // March through vectors in multiple of 4. + for( i = 0; (i + 3) < m; i += 4 ) + { + // Load the input values. + r0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + r1v.v = _mm256_loadu_pd( (double*) (y0 + 1*n_elem_per_reg )); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + if ( bli_is_conj(conja) ){ + /** + * For conjugate cases imaginary part + * is negated. + */ + conjv.v = _mm256_set_pd( -1, 1, -1, 1 ); + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + a00v.v = _mm256_mul_pd(a00v.v, conjv.v); + a10v.v = _mm256_mul_pd(a10v.v, conjv.v); + a01v.v = _mm256_mul_pd(a01v.v, conjv.v); + a11v.v = _mm256_mul_pd(a11v.v, conjv.v); + a02v.v = _mm256_mul_pd(a02v.v, conjv.v); + a12v.v = _mm256_mul_pd(a12v.v, conjv.v); + a03v.v = _mm256_mul_pd(a03v.v, conjv.v); + a13v.v = _mm256_mul_pd(a13v.v, conjv.v); + a04v.v = _mm256_mul_pd(a04v.v, conjv.v); + a14v.v = _mm256_mul_pd(a14v.v, conjv.v); + } + else + { + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); + a10v.v = _mm256_loadu_pd( (double*) (a0 + 1*n_elem_per_reg )); + + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); + a11v.v = _mm256_loadu_pd( (double*) (a1 + 1*n_elem_per_reg )); + + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + a12v.v = _mm256_loadu_pd( (double*) (a2 + 1*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + a13v.v = _mm256_loadu_pd( (double*) (a3 + 1*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); + a14v.v = _mm256_loadu_pd( (double*) (a4 + 1*n_elem_per_reg )); + + } + + // perform : y += alpha * x; + /** + * chi[x]v.v holds real part of chi. + * chi[x]v.v holds imag part of chi. + * ys holds following computation: + * + * a[xx]v.v R1 I1 R2 I2 + * chi[x]v.v chi_R chi_R chi_R chi_R + * chi[x]v.v chi_I chi_I chi_I chi_I + * y[x]v.v R1*chi_R I1*chi_R R2*chi_R I2*chiR (compute with chi-real part) + * y[x]v.v R1*chi_I I1*chi_I R2*chi_I I2*chiI (compute with chi-imag part) + * + */ + y0v.v = _mm256_mul_pd( a00v.v, chi0v.v); + y1v.v = _mm256_mul_pd( a10v.v, chi0v.v); + + y2v.v = _mm256_mul_pd( a00v.v, chi5v.v); + y3v.v = _mm256_mul_pd( a10v.v, chi5v.v); + + /** + * y0v.v & y1v.v holds computation with real part of chi. + * y2v.v & y3v.v holds computaion with imag part of chi. + * Permute will swap the positions of elements in y2v.v & y3v.v + * as we need to perform: [ R*R + I*I & R*I + I*R]. + * Once dcomplex multiplication is done add the result into r0v.v + * r1v.v which holds axpy result of current tile which is being + * computed. + */ + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + /** + * Repeat the same computation as above + * for remaining tile. + */ + y0v.v = _mm256_mul_pd( a01v.v, chi1v.v ); + y1v.v = _mm256_mul_pd( a11v.v, chi1v.v ); + + y2v.v = _mm256_mul_pd( a01v.v, chi6v.v ); + y3v.v = _mm256_mul_pd( a11v.v, chi6v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a02v.v, chi2v.v); + y1v.v = _mm256_mul_pd( a12v.v, chi2v.v); + + y2v.v = _mm256_mul_pd( a02v.v, chi7v.v ); + y3v.v = _mm256_mul_pd( a12v.v, chi7v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a03v.v, chi3v.v ); + y1v.v = _mm256_mul_pd( a13v.v, chi3v.v ); + + y2v.v = _mm256_mul_pd( a03v.v, chi8v.v ); + y3v.v = _mm256_mul_pd( a13v.v, chi8v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + y0v.v = _mm256_setzero_pd(); + y1v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + y3v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a04v.v, chi4v.v ); + y1v.v = _mm256_mul_pd( a14v.v, chi4v.v ); + + y2v.v = _mm256_mul_pd( a04v.v, chi9v.v ); + y3v.v = _mm256_mul_pd( a14v.v, chi9v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y3v.v = _mm256_permute_pd(y3v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + y1v.v = _mm256_addsub_pd(y1v.v, y3v.v); + + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + r1v.v = _mm256_add_pd(y1v.v, r1v.v); + + /** + * Final axpy compuation is available in r0v.v + * and r1v.v registers. + * Store it back into y vector. + */ + _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), r0v.v ); + _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), r1v.v ); + + /** + * Set the pointers next vectors elements to be + * computed based on unroll factor. + */ + y0 += n_elem_per_reg * n_iter_unroll; + a0 += n_elem_per_reg * n_iter_unroll; + a1 += n_elem_per_reg * n_iter_unroll; + a2 += n_elem_per_reg * n_iter_unroll; + a3 += n_elem_per_reg * n_iter_unroll; + a4 += n_elem_per_reg * n_iter_unroll; + } + // March through vectors in multiple of 2. + for( ; (i + 1) < m; i += 2 ) + { + r0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); - // For next 4 elements perform : y += alpha * x; - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + if ( bli_is_conj(conja) ){ + conjv.v = _mm256_set_pd( -1, 1, -1, 1 ); + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - y1v.v = _mm256_fmadd_pd( a15v.v, chi5v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a16v.v, chi6v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a17v.v, chi7v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a18v.v, chi8v.v, y1v.v ); - y1v.v = _mm256_fmadd_pd( a19v.v, chi9v.v, y1v.v ); + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - // Store the output. - _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double*) (y0 + 1*n_elem_per_reg), y1v.v ); + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); - y0 += n_elem_per_reg * n_iter_unroll; - a0 += n_elem_per_reg * n_iter_unroll; - a1 += n_elem_per_reg * n_iter_unroll; - a2 += n_elem_per_reg * n_iter_unroll; - a3 += n_elem_per_reg * n_iter_unroll; - a4 += n_elem_per_reg * n_iter_unroll; - } - for( ; (i + 1) < m; i += 2 ) - { - // Load the input values. - y0v.v = _mm256_loadu_pd( (double*) (y0 + 0*n_elem_per_reg )); - - a00v.v = _mm256_loadu_pd( (double*)(a0 + 0*n_elem_per_reg) ); - a01v.v = _mm256_loadu_pd( (double*)(a1 + 0*n_elem_per_reg) ); - a02v.v = _mm256_loadu_pd( (double*)(a2 + 0*n_elem_per_reg) ); - a03v.v = _mm256_loadu_pd( (double*)(a3 + 0*n_elem_per_reg) ); - a04v.v = _mm256_loadu_pd( (double*)(a4 + 0*n_elem_per_reg) ); - - a00v.v = _mm256_mul_pd( a00v.v, setPlus.v ); - a01v.v = _mm256_mul_pd( a01v.v, setPlus.v ); - a02v.v = _mm256_mul_pd( a02v.v, setPlus.v ); - a03v.v = _mm256_mul_pd( a03v.v, setPlus.v ); - a04v.v = _mm256_mul_pd( a04v.v, setPlus.v ); - - a05v.v = _mm256_mul_pd( a00v.v, setMinus.v ); - a06v.v = _mm256_mul_pd( a01v.v, setMinus.v ); - a07v.v = _mm256_mul_pd( a02v.v, setMinus.v ); - a08v.v = _mm256_mul_pd( a03v.v, setMinus.v ); - a09v.v = _mm256_mul_pd( a04v.v, setMinus.v ); - - a05v.v = _mm256_permute_pd( a05v.v, 5 ); - a06v.v = _mm256_permute_pd( a06v.v, 5 ); - a07v.v = _mm256_permute_pd( a07v.v, 5 ); - a08v.v = _mm256_permute_pd( a08v.v, 5 ); - a09v.v = _mm256_permute_pd( a09v.v, 5 ); + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); - // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - y0v.v = _mm256_fmadd_pd( a05v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a06v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a07v.v, chi7v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a08v.v, chi8v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a09v.v, chi9v.v, y0v.v ); + a00v.v = _mm256_mul_pd(a00v.v, conjv.v); + a01v.v = _mm256_mul_pd(a01v.v, conjv.v); + a02v.v = _mm256_mul_pd(a02v.v, conjv.v); + a03v.v = _mm256_mul_pd(a03v.v, conjv.v); + a04v.v = _mm256_mul_pd(a04v.v, conjv.v); + } + else + { + a00v.v = _mm256_loadu_pd( (double*) (a0 + 0*n_elem_per_reg )); - // Store the output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); + a01v.v = _mm256_loadu_pd( (double*) (a1 + 0*n_elem_per_reg )); - y0 += n_elem_per_reg ; - a0 += n_elem_per_reg ; - a1 += n_elem_per_reg ; - a2 += n_elem_per_reg ; - a3 += n_elem_per_reg ; - a4 += n_elem_per_reg ; - } - // If there are leftover iterations, perform them with scalar code. - for ( ; (i + 0) < m ; ++i ) - { - dcomplex y0c = *y0; - - const dcomplex a0c = *a0; - const dcomplex a1c = *a1; - const dcomplex a2c = *a2; - const dcomplex a3c = *a3; - const dcomplex a4c = *a4; - - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; - y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + a02v.v = _mm256_loadu_pd( (double*) (a2 + 0*n_elem_per_reg )); + + a03v.v = _mm256_loadu_pd( (double*) (a3 + 0*n_elem_per_reg )); + + a04v.v = _mm256_loadu_pd( (double*) (a4 + 0*n_elem_per_reg )); - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; - y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; - - *y0 = y0c; - - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; - y0 += 1; - } - } - else - { - for ( ; (i + 0) < m ; ++i ) - { - dcomplex y0c = *y0; - - const dcomplex a0c = *a0; - const dcomplex a1c = *a1; - const dcomplex a2c = *a2; - const dcomplex a3c = *a3; - const dcomplex a4c = *a4; - - y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; - y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; - y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; - y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; - y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; - - y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; - y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; - y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; - y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; - y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; - - *y0 = y0c; - - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - y0 += incy; - } - - } + } + + // perform : y += alpha * x; + /** + * chi[x]v.v holds real part of chi. + * chi[x]v.v holds imag part of chi. + * ys holds following computation: + * + * a[xx]v.v R1 I1 R2 I2 + * chi[x]v.v chi_R chi_R chi_R chi_R + * chi[x]v.v chi_I chi_I chi_I chi_I + * y[x]v.v R1*chi_R I1*chi_R R2*chi_R I2*chiR (compute with chi-real part) + * y[x]v.v R1*chi_I I1*chi_I R2*chi_I I2*chiI (compute with chi-imag part) + * + */ + y0v.v = _mm256_mul_pd( a00v.v, chi0v.v ); + y2v.v = _mm256_mul_pd( a00v.v, chi5v.v ); + + /** + * y0v.v holds computation with real part of chi. + * y2v.v holds computaion with imag part of chi. + * Permute will swap the positions of elements in y2v.v. + * as we need to perform: [ R*R + I*I & R*I + I*R]. + * Once dcomplex multiplication is done add the result into r0v.v + * which holds axpy result of current tile which is being + * computed. + */ + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + /** + * Repeat the same computation as above + * for remaining tile. + */ + y0v.v = _mm256_mul_pd( a01v.v, chi1v.v ); + y2v.v = _mm256_mul_pd( a01v.v, chi6v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a02v.v, chi2v.v ); + y2v.v = _mm256_mul_pd( a02v.v, chi7v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a03v.v, chi3v.v ); + y2v.v = _mm256_mul_pd( a03v.v, chi8v.v ); + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + y0v.v = _mm256_setzero_pd(); + y2v.v = _mm256_setzero_pd(); + + + y0v.v = _mm256_mul_pd( a04v.v, chi4v.v ); + y2v.v = _mm256_mul_pd( a04v.v, chi9v.v ); + + + y2v.v = _mm256_permute_pd(y2v.v, 0x5); + y0v.v = _mm256_addsub_pd(y0v.v, y2v.v); + r0v.v = _mm256_add_pd(y0v.v, r0v.v); + + /** + * Final axpy compuation is available in r0v.v + * Store it back into y vector. + */ + _mm256_storeu_pd( (double*) (y0 + 0*n_elem_per_reg), r0v.v ); + + y0 += n_iter_unroll; + a0 += n_iter_unroll; + a1 += n_iter_unroll; + a2 += n_iter_unroll; + a3 += n_iter_unroll; + a4 += n_iter_unroll; + + } + + // If there are leftover iterations, perform them with scalar code. + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += 1; + a1 += 1; + a2 += 1; + a3 += 1; + a4 += 1; + y0 += 1; + } + } + else + { + for ( ; (i + 0) < m ; ++i ) + { + dcomplex y0c = *y0; + + const dcomplex a0c = *a0; + const dcomplex a1c = *a1; + const dcomplex a2c = *a2; + const dcomplex a3c = *a3; + const dcomplex a4c = *a4; + + y0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag * setPlusOne; + y0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag * setPlusOne; + y0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag * setPlusOne; + y0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag * setPlusOne; + y0c.real += chi4.real * a4c.real - chi4.imag * a4c.imag * setPlusOne; + + y0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag * setPlusOne; + y0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag * setPlusOne; + y0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag * setPlusOne; + y0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag * setPlusOne; + y0c.imag += chi4.imag * a4c.real + chi4.real * a4c.imag * setPlusOne; + + *y0 = y0c; + + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + a4 += inca; + y0 += incy; + } + + } } From 906302588ff53e2e817ac71e584975dc1a67cebb Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 23 Dec 2021 04:44:24 -0600 Subject: [PATCH 15/63] Optimized daxpy2v implementation - Optimized axpy2v implementation for double datatype by handling rows in mulitple of 4 and store the final computed result at the end of computation, preventing unnecessary stores for improving the performance. - Optimal and reuse of vector registers for faster computation. AMD-Internal: [CPUPL-1973] Change-Id: I7b8ef94d0f67c1c666fdce26e9b2b7291365d2e9 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 8 +- config/zen3/bli_cntx_init_zen3.c | 4 +- kernels/zen/1f/CMakeLists.txt | 1 + kernels/zen/1f/bli_axpy2v_zen_int.c | 188 ++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 11 +- 6 files changed, 207 insertions(+), 9 deletions(-) create mode 100644 kernels/zen/1f/bli_axpy2v_zen_int.c diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 020e7052b9..ec356fd231 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,7 +80,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, @@ -89,6 +89,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + //axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 315362067e..47846ef22d 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,15 +92,17 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, - BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, + BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index ef47987454..7e7b120832 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 6, + 7, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -101,6 +101,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + // axpy2v + BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx ); diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index d2bf13822d..4b9caa40b6 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -7,4 +7,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_5.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c ) diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c new file mode 100644 index 0000000000..4ddca52162 --- /dev/null +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -0,0 +1,188 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2018, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ +#include "blis.h" +#include "immintrin.h" + + +/** + * daxpy2v kernel performs axpy2v operation. + * z := y + alphax * conjx(x) + alphay * conjy(y) + * where x, y, and z are vectors of length n. + */ +void bli_daxpy2v_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + double* restrict alphax, + double* restrict alphay, + double* restrict x, inc_t incx, + double* restrict y, inc_t incy, + double* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + if ( bli_zero_dim1( n ) ) return; + + if ( incz == 1 && incx == 1 && incy == 1 ) + { + dim_t i = 0; + dim_t rem = n%4; + const dim_t n_elem_per_reg = 4; + __m256d xv[4], yv[4], zv[4]; + __m256d alphaxv, alphayv; + + alphaxv = _mm256_broadcast_sd((double const*) alphax); + alphayv = _mm256_broadcast_sd((double const*) alphay); + + for( ; (i + 15) < n; i+= 16 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x + 3*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y + 3*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z + 1*n_elem_per_reg ); + zv[2] = _mm256_loadu_pd( z + 2*n_elem_per_reg ); + zv[3] = _mm256_loadu_pd( z + 3*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphaxv, zv[1]); + zv[2] = _mm256_fmadd_pd(xv[2], alphaxv, zv[2]); + zv[3] = _mm256_fmadd_pd(xv[3], alphaxv, zv[3]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + zv[1] = _mm256_fmadd_pd(yv[1], alphayv, zv[1]); + zv[2] = _mm256_fmadd_pd(yv[2], alphayv, zv[2]); + zv[3] = _mm256_fmadd_pd(yv[3], alphayv, zv[3]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + _mm256_storeu_pd((z + 1*n_elem_per_reg), zv[1]); + _mm256_storeu_pd((z + 2*n_elem_per_reg), zv[2]); + _mm256_storeu_pd((z + 3*n_elem_per_reg), zv[3]); + + z += 4*n_elem_per_reg; + x += 4*n_elem_per_reg; + y += 4*n_elem_per_reg; + } + + for( ; (i + 7) < n; i+= 8 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x + 1*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y + 1*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z + 1*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + zv[1] = _mm256_fmadd_pd(xv[1], alphaxv, zv[1]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + zv[1] = _mm256_fmadd_pd(yv[1], alphayv, zv[1]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + _mm256_storeu_pd((z + 1*n_elem_per_reg), zv[1]); + + z += 2*n_elem_per_reg; + x += 2*n_elem_per_reg; + y += 2*n_elem_per_reg; + } + + for( ; (i + 3) < n; i+= 4 ) + { + xv[0] = _mm256_loadu_pd( x + 0*n_elem_per_reg ); + + yv[0] = _mm256_loadu_pd( y + 0*n_elem_per_reg ); + + zv[0] = _mm256_loadu_pd( z + 0*n_elem_per_reg ); + + zv[0] = _mm256_fmadd_pd(xv[0], alphaxv, zv[0]); + + zv[0] = _mm256_fmadd_pd(yv[0], alphayv, zv[0]); + + _mm256_storeu_pd((z + 0*n_elem_per_reg), zv[0]); + + z += n_elem_per_reg; + x += n_elem_per_reg; + y += n_elem_per_reg; + } + if(rem) + { + PRAGMA_SIMD + for ( i = 0; i < rem; ++i ) + { + PASTEMAC(d,axpys)( *alphax, x[i], z[i] ); + PASTEMAC(d,axpys)( *alphay, y[i], z[i] ); + } + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(d,type); + PASTECH(d,axpyv_ker_ft) kfp_av + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_AXPYV_KER, cntx ); + + kfp_av + ( + conjx, + n, + alphax, + x, incx, + z, incz, + cntx + ); + + kfp_av + ( + conjy, + n, + alphay, + y, incy, + z, incz, + cntx + ); + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 42a92809c2..2caf5a9a92 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -34,12 +34,13 @@ */ // hemv helper function void bli_pre_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); void bli_post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); + double *y, double *alpha, + dim_t cs_a, dim_t rs_a); + // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) @@ -122,6 +123,8 @@ AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_5 ) AXPYF_KER_PROT( scomplex, c, axpyf_zen_int_4 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) +// axpy2v (intrinsics) +AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) From 4a1acbcf8fadd227e433db87096762d1c2a00c32 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 17 Dec 2021 02:34:52 -0600 Subject: [PATCH 16/63] Optimized dher2 implementation - Impplemented her2 framework calls for transposed and non transposed kernel variants. - dher2 kernel operate over 4 columns at a time. It computes 4x4 triangular part of matrix first and remainder part is computed in chunk of 4x4 tile upto m rows. - remainder cases(m < 4) are handled serially. AMD-Internal: [CPUPL-1968] Change-Id: I12ae97b2ad673a7fd9b733c607f27b1089142313 --- frame/2/hemv/bli_hemv_unf_var1.c | 12 +- frame/2/hemv/bli_hemv_unf_var3.c | 11 + frame/2/her2/bli_her2_unf_var1.c | 212 +++++++++++++++ frame/2/her2/bli_her2_unf_var4.c | 187 ++++++++++++++ kernels/zen/2/CMakeLists.txt | 2 +- kernels/zen/2/bli_her2_zen_int_4.c | 396 +++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 10 - 7 files changed, 816 insertions(+), 14 deletions(-) create mode 100644 kernels/zen/2/bli_her2_zen_int_4.c diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index ccb39b3485..6790e5bd08 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -218,9 +218,15 @@ void PASTEMAC(ch,varname) \ #ifdef BLIS_CONFIG_EPYC -void post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); +void bli_post_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); void bli_dhemv_unf_var1 ( diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index 6ed18efea4..abf08dfdaf 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -217,6 +217,17 @@ void PASTEMAC(ch,varname) \ } #ifdef BLIS_CONFIG_EPYC + +void bli_pre_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + void bli_dhemv_unf_var3 ( uplo_t uplo, diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index a0aec48f71..299e3d161d 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,5 +158,217 @@ void PASTEMAC(ch,varname) \ } \ } + +#ifdef BLIS_CONFIG_EPYC + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) +#endif diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index 3dea31d53e..e39c7224c4 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,5 +166,192 @@ void PASTEMAC(ch,varname) \ } \ } +#ifdef BLIS_CONFIG_EPYC +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) +#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) +#endif diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index dfa7c0b750..f20d114781 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -3,7 +3,7 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c ) diff --git a/kernels/zen/2/bli_her2_zen_int_4.c b/kernels/zen/2/bli_her2_zen_int_4.c new file mode 100644 index 0000000000..9b181aa278 --- /dev/null +++ b/kernels/zen/2/bli_her2_zen_int_4.c @@ -0,0 +1,396 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "immintrin.h" +#include "blis.h" + +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ) +{ + dim_t row = 0; + dim_t rem = m % 4; + + /*holds 4 diagonal elements of triangular part of 4x4 tile*/ + double a_diag[4] = {0}; + /*alpha_chi holds x*alpha and alpha_psi holds y*alpha*/ + double alpha_chi[4] = {0}; + double alpha_psi[4] = {0}; + /*Extracts diagonal element and store into a_diag buffer*/ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + a_diag[i] = *(a + m + i + (i * lda)); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + + __m256d xr, yr, zero, gamma; + __m256d a0, a1, a2, a3; + + zero = _mm256_setzero_pd(); + + /*Loading elements of x & y vectors*/ + x0 = _mm256_loadu_pd(x + m); + y0 = _mm256_loadu_pd(y + m); + /*Broadcasting alpha to compute alpha_psi and alpha_chi*/ + x1 = _mm256_broadcast_sd(alpha); + + x2 = _mm256_mul_pd(x0, x1); + y0 = _mm256_mul_pd(y0, x1); + + /*Storing alpha_chi and alpha_psi for later usage in computation loop*/ + _mm256_storeu_pd(alpha_chi, x2); + _mm256_storeu_pd(alpha_psi, y0); + + x0 = _mm256_mul_pd(x0, y0); + gamma = _mm256_loadu_pd(a_diag); + gamma = _mm256_add_pd(gamma, x0); + gamma = _mm256_add_pd(gamma, x0); + _mm256_storeu_pd(a_diag, gamma); + + /* Broadcasting 4 alpha_psis and alpha_chis which + * are to be used througout the computation of 4x4 tile + * upto m rows. + */ + x0 = _mm256_broadcast_sd(&alpha_chi[0]); + x1 = _mm256_broadcast_sd(&alpha_chi[1]); + x2 = _mm256_broadcast_sd(&alpha_chi[2]); + x3 = _mm256_broadcast_sd(&alpha_chi[3]); + + y0 = _mm256_broadcast_sd(&alpha_psi[0]); + y1 = _mm256_broadcast_sd(&alpha_psi[1]); + y2 = _mm256_broadcast_sd(&alpha_psi[2]); + y3 = _mm256_broadcast_sd(&alpha_psi[3]); + + /* Loading 4x4 tile of A matrix for + * triangular part computation + */ + a0 = _mm256_loadu_pd(a + (0 * lda) + m); + a1 = _mm256_loadu_pd(a + (1 * lda) + m); + a2 = _mm256_loadu_pd(a + (2 * lda) + m); + a3 = _mm256_loadu_pd(a + (3 * lda) + m); + + yr = _mm256_loadu_pd(y); + xr = _mm256_loadu_pd(x); + + /*Setting first element of x & y vectors to zero + * to eliminate diagonal element of 1st column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x1); + yr = _mm256_blend_pd(yr, zero, 0x1); + a0 = _mm256_blend_pd(a0, zero, 0x1); + + a1 = _mm256_blend_pd(a1, zero, 0x3); + a2 = _mm256_blend_pd(a2, zero, 0x7); + a3 = _mm256_blend_pd(a3, zero, 0xF); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a0 = _mm256_fmadd_pd(yr, x0, a0); + + /*Setting two elements of x & y vectors to zero + * to eliminate diagonal element of 2nd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x3); + yr = _mm256_blend_pd(yr, zero, 0x3); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a1 = _mm256_fmadd_pd(yr, x1, a1); + + /*Setting three elements of x & y vectors to zero + * to eliminate diagonal element of 3rd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x7); + yr = _mm256_blend_pd(yr, zero, 0x7); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a2 = _mm256_fmadd_pd(yr, x2, a2); + + _mm256_storeu_pd(a + (0 * lda) + m, a0 ); + + /* Loading data from memory location first + * so it could be blend with and finally + * gets stored at same location to prevent + * unnecessary data overwriting at nearby + * memory locations + */ + a3 = _mm256_loadu_pd(a + (1 * lda) + m ); + a1 = _mm256_blend_pd(a1, a3, 0x1); + _mm256_storeu_pd(a + (1 * lda) + m, a1 ); + + a3 = _mm256_loadu_pd(a + (2 * lda) + m ); + a2 = _mm256_blend_pd(a2, a3, 0x3); + _mm256_storeu_pd(a + (2 * lda) + m, a2 ); + + /* Triangular part of matrix is computed, remaining + * part is computed in below loop upto m rows. + */ + for(; (row + 4) <= m; row+=4) + { + /* Loading elements of x and y vector */ + xr = _mm256_loadu_pd(x + row); + yr = _mm256_loadu_pd(y + row); + /* Loading tile of A matrix of size 4x4 */ + a0 = _mm256_loadu_pd(a + row + (0 * lda) ); + a1 = _mm256_loadu_pd(a + row + (1 * lda) ); + a2 = _mm256_loadu_pd(a + row + (2 * lda) ); + a3 = _mm256_loadu_pd(a + row + (3 * lda) ); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a3 = _mm256_fmadd_pd(xr, y3, a3); + + a0 = _mm256_fmadd_pd(yr, x0, a0); + a1 = _mm256_fmadd_pd(yr, x1, a1); + a2 = _mm256_fmadd_pd(yr, x2, a2); + a3 = _mm256_fmadd_pd(yr, x3, a3); + + _mm256_storeu_pd(a + row + (0 * lda), a0); + _mm256_storeu_pd(a + row + (1 * lda), a1); + _mm256_storeu_pd(a + row + (2 * lda), a2); + _mm256_storeu_pd(a + row + (3 * lda), a3); + } + + /* Computes remainder cases where m is less than 4 */ + if(rem) + { + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + for(dim_t j = row; j < m; j++) + { + a[ j + (i * lda)] += x[j] * (y[i] * (*alpha)); + a[ j + (i * lda)] += y[j] * (x[i] * (*alpha)); + } + } + } + + /* Computing 4 diagonal elements of triangular part of matrix + * and storing result back at corresponding location in matrix A + */ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + *(a + m + i + (i * lda)) = a_diag[i]; + } +} + + +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ) +{ + dim_t row = 4; + dim_t rem = m % 4; + + /*holds 4 diagonal elements of triangular part of 4x4 tile*/ + double a_diag[4] = {0}; + /*alpha_chi holds x*alpha and alpha_psi holds y*alpha*/ + double alpha_chi[4] = {0}; + double alpha_psi[4] = {0}; + /*Extracts diagonal element and store into a_diag buffer*/ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + a_diag[i] = *(a + i + (i * lda)); + } + + __m256d x0, x1, x2, x3; + __m256d y0, y1, y2, y3; + + __m256d xr, yr, zero, gamma; + __m256d a0, a1, a2, a3; + + zero = _mm256_setzero_pd(); + + /*Loading elements of x & y vectors*/ + x0 = _mm256_loadu_pd(x); + y0 = _mm256_loadu_pd(y); + /*Broadcasting alpha to compute alpha_psi and alpha_chi*/ + x1 = _mm256_broadcast_sd(alpha); + + x2 = _mm256_mul_pd(x0, x1); + y0 = _mm256_mul_pd(y0, x1); + + /*Storing alpha_chi and alpha_psi for later usage in computation loop*/ + _mm256_storeu_pd(alpha_chi, x2); + _mm256_storeu_pd(alpha_psi, y0); + + x0 = _mm256_mul_pd(x0, y0); + gamma = _mm256_loadu_pd(a_diag); + gamma = _mm256_add_pd(gamma, x0); + gamma = _mm256_add_pd(gamma, x0); + _mm256_storeu_pd(a_diag, gamma); + + /* Broadcasting 4 alpha_psis and alpha_chis which + * are to be used througout the computation of 4x4 tile + * upto m rows. + */ + x0 = _mm256_broadcast_sd(&alpha_chi[0]); + x1 = _mm256_broadcast_sd(&alpha_chi[1]); + x2 = _mm256_broadcast_sd(&alpha_chi[2]); + x3 = _mm256_broadcast_sd(&alpha_chi[3]); + + y0 = _mm256_broadcast_sd(&alpha_psi[0]); + y1 = _mm256_broadcast_sd(&alpha_psi[1]); + y2 = _mm256_broadcast_sd(&alpha_psi[2]); + y3 = _mm256_broadcast_sd(&alpha_psi[3]); + + /* Loading 4x4 tile of A matrix for + * triangular part computation + */ + a0 = _mm256_loadu_pd(a + (0 * lda) ); + a1 = _mm256_loadu_pd(a + (1 * lda) ); + a2 = _mm256_loadu_pd(a + (2 * lda) ); + a3 = _mm256_loadu_pd(a + (3 * lda) ); + + yr = _mm256_loadu_pd(y); + xr = _mm256_loadu_pd(x); + + /*Setting first element of x & y vectors to zero + * to eliminate diagonal element of 1st column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x1); + yr = _mm256_blend_pd(yr, zero, 0x1); + a0 = _mm256_blend_pd(a0, zero, 0x1); + a1 = _mm256_blend_pd(a1, zero, 0x3); + a2 = _mm256_blend_pd(a2, zero, 0x7); + a3 = _mm256_blend_pd(a3, zero, 0xF); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a0 = _mm256_fmadd_pd(yr, x0, a0); + + /*Setting two elements of x & y vectors to zero + * to eliminate diagonal element of 2nd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x3); + yr = _mm256_blend_pd(yr, zero, 0x3); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a1 = _mm256_fmadd_pd(yr, x1, a1); + + /*Setting three elements of x & y vectors to zero + * to eliminate diagonal element of 3rd column + * from computation + */ + xr = _mm256_blend_pd(xr, zero, 0x7); + yr = _mm256_blend_pd(yr, zero, 0x7); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a2 = _mm256_fmadd_pd(yr, x2, a2); + + _mm256_storeu_pd(a + (0 * lda), a0 ); + + /* Loading data from memory location first + * so it could be blend with and finally + * gets stored at same location to prevent + * unnecessary data overwriting at nearby + * memory locations + */ + a3 = _mm256_loadu_pd(a + (1 * lda) ); + a1 = _mm256_blend_pd(a1, a3, 0x1); + _mm256_storeu_pd(a + (1 * lda), a1 ); + + a3 = _mm256_loadu_pd(a + (2 * lda) ); + a2 = _mm256_blend_pd(a2, a3, 0x3); + _mm256_storeu_pd(a + (2 * lda), a2 ); + + /* Triangular part of matrix is computed, remaining + * part is computed in below loop upto m rows. + */ + for(; (row + 4) <= m; row+=4) + { + /* Loading elements of x and y vector */ + xr = _mm256_loadu_pd(x + row); + yr = _mm256_loadu_pd(y + row); + /* Loading tile of A matrix of size 4x4 */ + a0 = _mm256_loadu_pd(a + row + (0 * lda) ); + a1 = _mm256_loadu_pd(a + row + (1 * lda) ); + a2 = _mm256_loadu_pd(a + row + (2 * lda) ); + a3 = _mm256_loadu_pd(a + row + (3 * lda) ); + + a0 = _mm256_fmadd_pd(xr, y0, a0); + a1 = _mm256_fmadd_pd(xr, y1, a1); + a2 = _mm256_fmadd_pd(xr, y2, a2); + a3 = _mm256_fmadd_pd(xr, y3, a3); + + a0 = _mm256_fmadd_pd(yr, x0, a0); + a1 = _mm256_fmadd_pd(yr, x1, a1); + a2 = _mm256_fmadd_pd(yr, x2, a2); + a3 = _mm256_fmadd_pd(yr, x3, a3); + + _mm256_storeu_pd(a + row + (0 * lda), a0); + _mm256_storeu_pd(a + row + (1 * lda), a1); + _mm256_storeu_pd(a + row + (2 * lda), a2); + _mm256_storeu_pd(a + row + (3 * lda), a3); + } + + /* Computes remainder cases where m is less than 4 */ + if(rem) + { + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + for(dim_t j = row; j < m; j++) + { + a[ j + (i * lda)] += x[j] * (y[i] * (*alpha)); + a[ j + (i * lda)] += y[j] * (x[i] * (*alpha)); + } + } + } + + /* Computing 4 diagonal elements of triangular part of matrix + * and storing result back at corresponding location in matrix A + */ + PRAGMA_SIMD + for(dim_t i = 0; i < 4; i++) + { + *(a + i + (i * lda)) = a_diag[i]; + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 2caf5a9a92..7edc0a9a1a 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -32,16 +32,6 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ -// hemv helper function -void bli_pre_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); - -void bli_post_hemv_8x8(double *a, double *x, - double *y, double *alpha, - dim_t cs_a, dim_t rs_a); - - // -- level-1m -- PACKM_KER_PROT(double, d, packm_8xk_gen_zen) PACKM_KER_PROT(double, d, packm_6xk_gen_zen) From 2ef8d7a7d319b5b9e58c507161c640eb46b23704 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 21 Dec 2021 15:08:16 +0530 Subject: [PATCH 17/63] Eliminating barriers in SUP path when matrices are not packed. -Current gemm SUP path uses bli_thrinfo_sup_grow, bli_thread_range_sub to generate per thread data ranges at each loop of gemm algorithm. bli_thrinfo_sup_grow involves usage of multiple barriers for cross thread synchronization. These barriers are necessary in cases where either the A or B matrix are packed for centralized pack buffer allocation/deallocation (bli_thread_am_ochief thread). -However for cases where both A and B matrices are unpacked, these barrier are resulting in overhead for smaller dimensions. Here creation of unnecessary communicators are avoided and subsequently the requirement for barriers are eliminated when packing is disabled for both the input matrices in SUP path. Change-Id: Ic373dfd2d6b08b8f577dc98399a83bb08f794afa --- frame/thread/bli_thrinfo_sup.c | 126 ++++++++++++++++++++++----------- 1 file changed, 83 insertions(+), 43 deletions(-) diff --git a/frame/thread/bli_thrinfo_sup.c b/frame/thread/bli_thrinfo_sup.c index e67e8b6426..8ce714547c 100644 --- a/frame/thread/bli_thrinfo_sup.c +++ b/frame/thread/bli_thrinfo_sup.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -167,8 +167,23 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl thrcomm_t* static_comms[ BLIS_NUM_STATIC_COMMS ]; thrcomm_t** new_comms = NULL; + + const bool packa = bli_rntm_pack_a( rntm ); + const bool packb = bli_rntm_pack_b( rntm ); + dim_t parent_nt_in = 0; + + // thrinfo ocomm is not created when neither packa nor packb is + // enabled. Need to derive parent_nt_in without depending on ocomm in + // those cases. + if ( packa || packb ) + { + parent_nt_in = bli_thread_num_threads( thread_par ); + } + else + { + parent_nt_in = bli_rntm_calc_num_threads_in( bszid_par, rntm ); + } - const dim_t parent_nt_in = bli_thread_num_threads( thread_par ); const dim_t parent_n_way = bli_thread_n_way( thread_par ); const dim_t parent_comm_id = bli_thread_ocomm_id( thread_par ); const dim_t parent_work_id = bli_thread_work_id( thread_par ); @@ -193,50 +208,75 @@ thrinfo_t* bli_thrinfo_sup_create_for_cntl //printf( "thread %d: child_n_way = %d child_nt_in = %d parent_n_way = %d (bszid = %d->%d)\n", (int)child_comm_id, (int)child_nt_in, (int)child_n_way, (int)parent_n_way, (int)bli_cntl_bszid( cntl_par ), (int)bszid_chl ); - // The parent's chief thread creates a temporary array of thrcomm_t - // pointers. - if ( bli_thread_am_ochief( thread_par ) ) + thrinfo_t* thread_chl = NULL; + + // The communicators are only used when either packa or packb is + // enabled. This means that the communicator creation along with the + // overhead from the barriers (required for synchronizing comm across + // threads) are not required when both packa and packb are disabled. + if ( packa || packb ) { - if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); - else - new_comms = static_comms; - } + // The parent's chief thread creates a temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + new_comms = bli_malloc_intl( parent_n_way * sizeof( thrcomm_t* ) ); + else + new_comms = static_comms; + } + + // Broadcast the temporary array to all threads in the parent's + // communicator. + new_comms = bli_thread_broadcast( thread_par, new_comms ); + + // Chiefs in the child communicator allocate the communicator + // object and store it in the array element corresponding to the + // parent's work id. + if ( child_comm_id == 0 ) + new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); + + bli_thread_barrier( thread_par ); + + // All threads create a new thrinfo_t node using the communicator + // that was created by their chief, as identified by parent_work_id. + thread_chl = bli_thrinfo_create + ( + rntm, // rntm + new_comms[ parent_work_id ], // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + TRUE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); + + bli_thread_barrier( thread_par ); - // Broadcast the temporary array to all threads in the parent's - // communicator. - new_comms = bli_thread_broadcast( thread_par, new_comms ); - - // Chiefs in the child communicator allocate the communicator - // object and store it in the array element corresponding to the - // parent's work id. - if ( child_comm_id == 0 ) - new_comms[ parent_work_id ] = bli_thrcomm_create( rntm, child_nt_in ); - - bli_thread_barrier( thread_par ); - - // All threads create a new thrinfo_t node using the communicator - // that was created by their chief, as identified by parent_work_id. - thrinfo_t* thread_chl = bli_thrinfo_create - ( - rntm, // rntm - new_comms[ parent_work_id ], // ocomm - child_comm_id, // ocomm_id - child_n_way, // n_way - child_work_id, // work_id - TRUE, // free_comm - *bszid_chl, // bszid - NULL // sub_node - ); - - bli_thread_barrier( thread_par ); - - // The parent's chief thread frees the temporary array of thrcomm_t - // pointers. - if ( bli_thread_am_ochief( thread_par ) ) + // The parent's chief thread frees the temporary array of thrcomm_t + // pointers. + if ( bli_thread_am_ochief( thread_par ) ) + { + if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) + bli_free_intl( new_comms ); + } + } + else { - if ( parent_n_way > BLIS_NUM_STATIC_COMMS ) - bli_free_intl( new_comms ); + // No communicator is reqiured in cases where neither packa nor + // packb is enabled. + thread_chl = bli_thrinfo_create + ( + rntm, // rntm + NULL, // ocomm + child_comm_id, // ocomm_id + child_n_way, // n_way + child_work_id, // work_id + FALSE, // free_comm + *bszid_chl, // bszid + NULL // sub_node + ); } return thread_chl; From 34e1091031ea1eef1e70a817b5cef54d40f75cef Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 30 Dec 2021 22:59:39 -0600 Subject: [PATCH 18/63] Implemented optimal S/DCOMPLEX dotxf kernel - Optimized dotxf implementation for double and single precision complex datatype by handling dot product computation in tile 2x6 and 4x6 handling 6 columns at a time, and rows in multiple of 2 and 4. - Dot product computation is arranged such a way that multiple rho vector register will hold the temporary result till the end of loop and finally does horizontal addition to get final dot product result. - Corner cases are handled serially. - Optimal and reuse of vector registers for faster computation. AMD-Internal: [CPUPL-1975] Change-Id: I7dd305e73adf54100d54661769c7d5aada9b0098 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- kernels/zen/1f/bli_dotxf_zen_int_8.c | 902 ++++++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 3 +- 5 files changed, 912 insertions(+), 5 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index ec356fd231..eed39b3149 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,7 +80,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, @@ -89,6 +89,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, //axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 47846ef22d..f6b8eef1e4 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -101,6 +101,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 7e7b120832..a043d5ad22 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 7, + 9, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -101,6 +101,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, + BLIS_DOTXF_KER, BLIS_DCOMPLEX, bli_zdotxf_zen_int_6, + BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, cntx diff --git a/kernels/zen/1f/bli_dotxf_zen_int_8.c b/kernels/zen/1f/bli_dotxf_zen_int_8.c index ad27403bdc..815e388f21 100644 --- a/kernels/zen/1f/bli_dotxf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2017 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -1542,4 +1542,904 @@ void bli_ddotxf_zen_int_2 } } +/** + * Performs dotxf operation on dcomplex. + * x and y are vectors and a is the matrix. + * Computation is done on 6 columns at a time + * Marches through vectors in multiple of 2. + */ +void bli_zdotxf_zen_int_6 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + /** + * Handles only unit stride cases and 6 column at a time + * b_n check for columns to be 6. + */ + if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + { + /* Temporary rho buffer holds computed dot product result */ + dcomplex r[ 6 ]; + + /* If beta is zero, clear y. Otherwise, scale by beta. */ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early*/ + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + /* Initialize r vector to 0. */ + for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(z,set0s)( r[i] ); + + /* If a must be conjugated, we do so indirectly by first + * toggling the effective conjugation of x and then conjugating + * the resulting do products. + * Rather conjugating each element of a matrix, final computed result + * can be conjugated at the end of loop. This takes off the overhead + * of conjugating each element inside the loop and improves the + * performance. + */ + conj_t conjx_use = conjx; + + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjx_use ); + } + + /* Setting rho vectors to 0 */ + v4df_t rho0v; rho0v.v = _mm256_setzero_pd(); + v4df_t rho1v; rho1v.v = _mm256_setzero_pd(); + v4df_t rho2v; rho2v.v = _mm256_setzero_pd(); + v4df_t rho3v; rho3v.v = _mm256_setzero_pd(); + v4df_t rho4v; rho4v.v = _mm256_setzero_pd(); + v4df_t rho5v; rho5v.v = _mm256_setzero_pd(); + + v4df_t rho6v; rho6v.v = _mm256_setzero_pd(); + v4df_t rho7v; rho7v.v = _mm256_setzero_pd(); + v4df_t rho8v; rho8v.v = _mm256_setzero_pd(); + v4df_t rho9v; rho9v.v = _mm256_setzero_pd(); + v4df_t rho10v; rho10v.v = _mm256_setzero_pd(); + v4df_t rho11v; rho11v.v = _mm256_setzero_pd(); + + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + v4df_t x0v, x1v; + /* Holds 2x6 tile of matrix A */ + v4df_t a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m256d no_conju = _mm256_setr_pd(-1, 1, -1, 1); + __m256d conju = _mm256_setr_pd(1, -1, 1, -1); + dim_t iter = m / 2; + dim_t rem = m % 2; + dim_t i = 0; + + if ( bli_is_noconj( conjx_use ) ) + { + if(iter) + { + for ( ; (i+1) < m; i+=2) + { + /*Load 2 dcomplex elements from + * vector x + */ + x0v.v = _mm256_loadu_pd( + (double *)(x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + * It will do following operation. + * R0 I0 R1 I1 => I0 I0 I1 I1 + * + */ + x1v.v = _mm256_permute_pd( x0v.v, 15 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + * It will do following operation. + * R0 I0 R1 I1 => R0 R0 R1 R1 + */ + x0v.v = _mm256_permute_pd( x0v.v, 0 ); + + /*Load 2x6 tile of matrix A*/ + a0v.v = _mm256_loadu_pd( (double *) + (a + i + 0 * lda) ); + a1v.v = _mm256_loadu_pd( (double *) + (a + i + 1 * lda) ); + a2v.v = _mm256_loadu_pd( (double *) + (a + i + 2 * lda) ); + a3v.v = _mm256_loadu_pd( (double *) + (a + i + 3 * lda) ); + a4v.v = _mm256_loadu_pd( (double *) + (a + i + 4 * lda) ); + a5v.v = _mm256_loadu_pd( (double *) + (a + i + 5 * lda) ); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, + x0v.v, rho0v.v ); + rho6v.v = _mm256_fmadd_pd( a0v.v, + x1v.v, rho6v.v ); + + rho1v.v = _mm256_fmadd_pd( a1v.v, + x0v.v, rho1v.v ); + rho7v.v = _mm256_fmadd_pd( a1v.v, + x1v.v, rho7v.v ); + + rho2v.v = _mm256_fmadd_pd( a2v.v, + x0v.v, rho2v.v ); + rho8v.v = _mm256_fmadd_pd( a2v.v, + x1v.v, rho8v.v ); + + rho3v.v = _mm256_fmadd_pd( a3v.v, + x0v.v, rho3v.v ); + rho9v.v = _mm256_fmadd_pd( a3v.v, + x1v.v, rho9v.v ); + + rho4v.v = _mm256_fmadd_pd( a4v.v, + x0v.v, rho4v.v ); + rho10v.v = _mm256_fmadd_pd( a4v.v, + x1v.v, rho10v.v ); + + rho5v.v = _mm256_fmadd_pd( a5v.v, + x0v.v, rho5v.v ); + rho11v.v = _mm256_fmadd_pd( a5v.v, + x1v.v, rho11v.v ); + } + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); + rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); + rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); + rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); + rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); + rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication + */ + rho6v.v = _mm256_mul_pd(rho6v.v, no_conju); + rho7v.v = _mm256_mul_pd(rho7v.v, no_conju); + rho8v.v = _mm256_mul_pd(rho8v.v, no_conju); + rho9v.v = _mm256_mul_pd(rho9v.v, no_conju); + rho10v.v = _mm256_mul_pd(rho10v.v, no_conju); + rho11v.v = _mm256_mul_pd(rho11v.v, no_conju); + + rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); + rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); + rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); + rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); + rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); + rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. + */ + rho0v.d[0] += rho0v.d[2]; + rho0v.d[1] += rho0v.d[3]; + + rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; + rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + + rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; + rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + + rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; + rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + + rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; + rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + + rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; + rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + + /*Computed dot product result is being stored + * in temp buffer r for further computation. + */ + _mm256_storeu_pd((double *)r, rho0v.v); + _mm256_storeu_pd((double *)(r+2) , rho1v.v); + _mm256_storeu_pd((double *)(r+4) , rho2v.v); + + } + /*handles remainder cases*/ + if(rem) + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(z,axpys)( a[i + p*lda] + , x[i], r[p] ); + } + } + } + else + { + if(iter) + { + for ( ; (i+1) < m; i+=2) + { + /*Load 2 dcomplex elements from + * vector x + */ + x0v.v = _mm256_loadu_pd( (double *) + (x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v.v = _mm256_permute_pd( x0v.v, 15 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v.v = _mm256_permute_pd( x0v.v, 0 ); + + /*Load 2x6 tile of matrix A*/ + a0v.v = _mm256_loadu_pd( (double *) + (a + i + 0 * lda)); + a1v.v = _mm256_loadu_pd( (double *) + (a + i + 1 * lda)); + a2v.v = _mm256_loadu_pd( (double *) + (a + i + 2 * lda)); + a3v.v = _mm256_loadu_pd( (double *) + (a + i + 3 * lda)); + a4v.v = _mm256_loadu_pd( (double *) + (a + i + 4 * lda)); + a5v.v = _mm256_loadu_pd( (double *) + (a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + rho0v.v = _mm256_fmadd_pd( a0v.v, + x0v.v, rho0v.v ); + rho6v.v = _mm256_fmadd_pd( a0v.v, + x1v.v, rho6v.v ); + + rho1v.v = _mm256_fmadd_pd( a1v.v, + x0v.v, rho1v.v ); + rho7v.v = _mm256_fmadd_pd( a1v.v, + x1v.v, rho7v.v ); + + rho2v.v = _mm256_fmadd_pd( a2v.v, + x0v.v, rho2v.v ); + rho8v.v = _mm256_fmadd_pd( a2v.v, + x1v.v, rho8v.v ); + + rho3v.v = _mm256_fmadd_pd( a3v.v, + x0v.v, rho3v.v ); + rho9v.v = _mm256_fmadd_pd( a3v.v, + x1v.v, rho9v.v ); + + rho4v.v = _mm256_fmadd_pd( a4v.v, + x0v.v, rho4v.v ); + rho10v.v = _mm256_fmadd_pd( a4v.v, + x1v.v, rho10v.v ); + + rho5v.v = _mm256_fmadd_pd( a5v.v, + x0v.v, rho5v.v ); + rho11v.v = _mm256_fmadd_pd( a5v.v, + x1v.v, rho11v.v ); + } + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v.v = _mm256_permute_pd(rho6v.v, 0x05); + rho7v.v = _mm256_permute_pd(rho7v.v, 0x05); + rho8v.v = _mm256_permute_pd(rho8v.v, 0x05); + rho9v.v = _mm256_permute_pd(rho9v.v, 0x05); + rho10v.v = _mm256_permute_pd(rho10v.v, 0x05); + rho11v.v = _mm256_permute_pd(rho11v.v, 0x05); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication + */ + rho6v.v = _mm256_mul_pd(rho6v.v, conju); + rho7v.v = _mm256_mul_pd(rho7v.v, conju); + rho8v.v = _mm256_mul_pd(rho8v.v, conju); + rho9v.v = _mm256_mul_pd(rho9v.v, conju); + rho10v.v = _mm256_mul_pd(rho10v.v, conju); + rho11v.v = _mm256_mul_pd(rho11v.v, conju); + + rho0v.v = _mm256_add_pd(rho0v.v, rho6v.v); + rho1v.v = _mm256_add_pd(rho1v.v, rho7v.v); + rho2v.v = _mm256_add_pd(rho2v.v, rho8v.v); + rho3v.v = _mm256_add_pd(rho3v.v, rho9v.v); + rho4v.v = _mm256_add_pd(rho4v.v, rho10v.v); + rho5v.v = _mm256_add_pd(rho5v.v, rho11v.v); + + /*rho0, rho1, rho2 holds final dot product + * result of 6 dcomplex elements. + */ + rho0v.d[0] += rho0v.d[2]; + rho0v.d[1] += rho0v.d[3]; + + rho0v.d[2] = rho1v.d[0] + rho1v.d[2]; + rho0v.d[3] = rho1v.d[1] + rho1v.d[3]; + + rho1v.d[0] = rho2v.d[0] + rho2v.d[2]; + rho1v.d[1] = rho2v.d[1] + rho2v.d[3]; + + rho1v.d[2] = rho3v.d[0] + rho3v.d[2]; + rho1v.d[3] = rho3v.d[1] + rho3v.d[3]; + + rho2v.d[0] = rho4v.d[0] + rho4v.d[2]; + rho2v.d[1] = rho4v.d[1] + rho4v.d[3]; + + rho2v.d[2] = rho5v.d[0] + rho5v.d[2]; + rho2v.d[3] = rho5v.d[1] + rho5v.d[3]; + + /*Computed dot product result is being stored + * in temp buffer r for further computation. + */ + _mm256_storeu_pd((double *)r, rho0v.v); + _mm256_storeu_pd((double *)(r+2) , rho1v.v); + _mm256_storeu_pd((double *)(r+4) , rho2v.v); + + } + if(rem) + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(z,axpyjs)(a[i + p*lda] + , x[i], r[p] ); + } + } + } + + if ( bli_is_conj( conjat ) ) + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,conjs)( r[i] ); + } + + /*scaling dot product result with alpha and + * adding the result to vector + */ + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(z,axpys)( *alpha, r[i], y[i] ); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(z,type); + PASTECH(z,dotxv_ker_ft) kfp_dv + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + dcomplex* restrict a1 = a + (0 )*inca + (i )*lda; + dcomplex* restrict x1 = x + (0 )*incx; + dcomplex* restrict psi1 = y + (i )*incy; + + kfp_dv + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + } + +} + + +/** + * Performs dotxf operation on scomplex. + * x and y are vectors and a is the matrix. + * Computation is done on 6 columns at a time + * Marches through vectors in multiple of 4 and 2. + */ +void bli_cdotxf_zen_int_6 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + cntx_t* restrict cntx + ) +{ + if ( (inca == 1) && (incx == 1) && (incy == 1) && (b_n == 6) ) + { + /* Temporary rho buffer holds computed dot product result */ + scomplex r[ 6 ]; + + /* If beta is zero, clear y. Otherwise, scale by beta. */ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,scals)( *beta, y[i] ); + } + } + + /* If the vectors are empty or if alpha is zero, return early. */ + if ( bli_zero_dim1( m ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + /* Initialize r vector to 0. */ + for ( dim_t i = 0; i < 6; ++i ) PASTEMAC(c,set0s)( r[i] ); + + /* If a must be conjugated, we do so indirectly by first toggling the + effective conjugation of x and then conjugating the resulting do + products. */ + conj_t conjx_use = conjx; + + if ( bli_is_conj( conjat ) ) + bli_toggle_conj( &conjx_use ); + + dim_t iter = m / 2; + dim_t iter4 = m / 4; + dim_t rem = m % 2; + dim_t i = 0; + if(iter) + { + if(iter4) + { + /* Setting rho vectors to 0 */ + __m256 rho0v; rho0v = _mm256_setzero_ps(); + __m256 rho1v; rho1v = _mm256_setzero_ps(); + __m256 rho2v; rho2v = _mm256_setzero_ps(); + __m256 rho3v; rho3v = _mm256_setzero_ps(); + __m256 rho4v; rho4v = _mm256_setzero_ps(); + __m256 rho5v; rho5v = _mm256_setzero_ps(); + + __m256 rho6v; rho6v = _mm256_setzero_ps(); + __m256 rho7v; rho7v = _mm256_setzero_ps(); + __m256 rho8v; rho8v = _mm256_setzero_ps(); + __m256 rho9v; rho9v = _mm256_setzero_ps(); + __m256 rho10v; rho10v = _mm256_setzero_ps(); + __m256 rho11v; rho11v = _mm256_setzero_ps(); + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + __m256 x0v, x1v; + /* Holds 2x6 tile of matrix A */ + __m256 a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m256 no_conju = _mm256_setr_ps(-1, 1, -1, 1, -1, 1, -1, 1); + __m256 conju = _mm256_setr_ps(1, -1, 1, -1, 1, -1, 1, -1); + + // March through vectos in multiple of 4. + for ( ; (i+3) < m; i+=4) + { + /*Load 4 scomplex elements from vector x*/ + x0v = _mm256_loadu_ps( (float *) (x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v = _mm256_permute_ps( x0v, 0xf5 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v = _mm256_permute_ps( x0v, 0xa0); + /* x1v.v holds imag part of dcomplex + Load 4x6 tile of matrix A*/ + a0v = _mm256_loadu_ps( (float *)(a + i + 0 * lda)); + a1v = _mm256_loadu_ps( (float *)(a + i + 1 * lda)); + a2v = _mm256_loadu_ps( (float *)(a + i + 2 * lda)); + a3v = _mm256_loadu_ps( (float *)(a + i + 3 * lda)); + a4v = _mm256_loadu_ps( (float *)(a + i + 4 * lda)); + a5v = _mm256_loadu_ps( (float *)(a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + + rho0v = _mm256_fmadd_ps( a0v, x0v, rho0v ); + rho6v = _mm256_fmadd_ps( a0v, x1v, rho6v ); + + rho1v = _mm256_fmadd_ps( a1v, x0v, rho1v ); + rho7v = _mm256_fmadd_ps( a1v, x1v, rho7v ); + + rho2v = _mm256_fmadd_ps( a2v, x0v, rho2v ); + rho8v = _mm256_fmadd_ps( a2v, x1v, rho8v ); + + rho3v = _mm256_fmadd_ps( a3v, x0v, rho3v ); + rho9v = _mm256_fmadd_ps( a3v, x1v, rho9v ); + + rho4v = _mm256_fmadd_ps( a4v, x0v, rho4v ); + rho10v = _mm256_fmadd_ps( a4v, x1v, rho10v ); + + rho5v = _mm256_fmadd_ps( a5v, x0v, rho5v ); + rho11v = _mm256_fmadd_ps( a5v, x1v, rho11v ); + } + + + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + + rho6v = _mm256_permute_ps(rho6v, 0xb1); + rho7v = _mm256_permute_ps(rho7v, 0xb1); + rho8v = _mm256_permute_ps(rho8v, 0xb1); + rho9v = _mm256_permute_ps(rho9v, 0xb1); + rho10v = _mm256_permute_ps(rho10v, 0xb1); + rho11v = _mm256_permute_ps(rho11v, 0xb1); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication*/ + if ( bli_is_noconj( conjx_use ) ) + { + rho6v = _mm256_mul_ps(rho6v, no_conju); + rho7v = _mm256_mul_ps(rho7v, no_conju); + rho8v = _mm256_mul_ps(rho8v, no_conju); + rho9v = _mm256_mul_ps(rho9v, no_conju); + rho10v = _mm256_mul_ps(rho10v, no_conju); + rho11v = _mm256_mul_ps(rho11v, no_conju); + } + else + { + + rho6v = _mm256_mul_ps(rho6v, conju); + rho7v = _mm256_mul_ps(rho7v, conju); + rho8v = _mm256_mul_ps(rho8v, conju); + rho9v = _mm256_mul_ps(rho9v, conju); + rho10v = _mm256_mul_ps(rho10v, conju); + rho11v = _mm256_mul_ps(rho11v, conju); + + } + + rho0v = _mm256_add_ps(rho0v, rho6v); + rho1v = _mm256_add_ps(rho1v, rho7v); + rho2v = _mm256_add_ps(rho2v, rho8v); + rho3v = _mm256_add_ps(rho3v, rho9v); + rho4v = _mm256_add_ps(rho4v, rho10v); + rho5v = _mm256_add_ps(rho5v, rho11v); + + /** + * Horizontal addition of rho elements + * for computing final dotxf result. + * ptr pointer addresses all 6 rho + * register one by one and store the + * computed result into r buffer. + */ + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t i = 0; i < 4; i++) + { + r[0].real += ptr[i].real; + r[0].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t i = 0; i < 4; i++) + { + r[1].real += ptr[i].real; + r[1].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t i = 0; i < 4; i++) + { + r[2].real += ptr[i].real; + r[2].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t i = 0; i < 4; i++) + { + r[3].real += ptr[i].real; + r[3].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho4v; + for(dim_t i = 0; i < 4; i++) + { + r[4].real += ptr[i].real; + r[4].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho5v; + for(dim_t i = 0; i < 4; i++) + { + r[5].real += ptr[i].real; + r[5].imag += ptr[i].imag; + } + } + // March through vectos in multiple of 2. + if(i+1 < m) + { + /* Setting rho vectors to 0 */ + __m128 rho0v; rho0v = _mm_setzero_ps(); + __m128 rho1v; rho1v = _mm_setzero_ps(); + __m128 rho2v; rho2v = _mm_setzero_ps(); + __m128 rho3v; rho3v = _mm_setzero_ps(); + __m128 rho4v; rho4v = _mm_setzero_ps(); + __m128 rho5v; rho5v = _mm_setzero_ps(); + + __m128 rho6v; rho6v = _mm_setzero_ps(); + __m128 rho7v; rho7v = _mm_setzero_ps(); + __m128 rho8v; rho8v = _mm_setzero_ps(); + __m128 rho9v; rho9v = _mm_setzero_ps(); + __m128 rho10v; rho10v = _mm_setzero_ps(); + __m128 rho11v; rho11v = _mm_setzero_ps(); + /* Holds 2 dcomplex element of x vector + * for computing dot product with A tile + */ + __m128 x0v, x1v; + /* Holds 2x6 tile of matrix A */ + __m128 a0v, a1v, a2v, a3v, a4v, a5v; + /** + * Since complex datatype multiplication is + * being held in two sets of rho vectors. + * Where first set holds the computaion with + * real part of vector x and other holds + * imaginary part of vector x. + * For final computation, based on conj sign + * of imaginary component needs to be toggled. + */ + __m128 no_conju = _mm_setr_ps(-1, 1, -1, 1); + __m128 conju = _mm_setr_ps(1, -1, 1, -1); + + for ( ; (i+1) < m; i+=2) + { + /*Load 4 scomplex elements from vector x*/ + x0v = _mm_loadu_ps( (float *)(x + i) ); + /* x1v.v holds imaginary part of dcomplex + * elements from vector x + */ + x1v = _mm_permute_ps( x0v, 0xf5 ); + /* x1v.v holds real part of dcomplex + * elements from vector x + */ + x0v = _mm_permute_ps( x0v, 0xa0); + /* x1v.v holds imag part of dcomplex + Load 4x6 tile of matrix A*/ + + a0v = _mm_loadu_ps( (float *)(a + i + 0 * lda)); + a1v = _mm_loadu_ps( (float *)(a + i + 1 * lda)); + a2v = _mm_loadu_ps( (float *)(a + i + 2 * lda)); + a3v = _mm_loadu_ps( (float *)(a + i + 3 * lda)); + a4v = _mm_loadu_ps( (float *)(a + i + 4 * lda)); + a5v = _mm_loadu_ps( (float *)(a + i + 5 * lda)); + + // perform: rho?v += a?v * x0v; + + rho0v = _mm_fmadd_ps( a0v, x0v, rho0v ); + rho6v = _mm_fmadd_ps( a0v, x1v, rho6v ); + + rho1v = _mm_fmadd_ps( a1v, x0v, rho1v ); + rho7v = _mm_fmadd_ps( a1v, x1v, rho7v ); + + rho2v = _mm_fmadd_ps( a2v, x0v, rho2v ); + rho8v = _mm_fmadd_ps( a2v, x1v, rho8v ); + + rho3v = _mm_fmadd_ps( a3v, x0v, rho3v ); + rho9v = _mm_fmadd_ps( a3v, x1v, rho9v ); + + rho4v = _mm_fmadd_ps( a4v, x0v, rho4v ); + rho10v = _mm_fmadd_ps( a4v, x1v, rho10v ); + + rho5v = _mm_fmadd_ps( a5v, x0v, rho5v ); + rho11v = _mm_fmadd_ps( a5v, x1v, rho11v ); + } + /*Swapping position of real and imag component + * for horizontal addition to get the final + * dot product computation + * rho register are holding computation which needs + * to be arranged in following manner. + * Ra0*Ix0 | Ia0*Ix0 | Ra1*Ix1 | Ia1*Ix1 + * || + * \/ + * Ia0*Ix0 | Ra0*Ix0 | Ia1*Ix1 | Ra1*Ix1 + */ + rho6v = _mm_permute_ps(rho6v, 0xb1); + rho7v = _mm_permute_ps(rho7v, 0xb1); + rho8v = _mm_permute_ps(rho8v, 0xb1); + rho9v = _mm_permute_ps(rho9v, 0xb1); + rho10v = _mm_permute_ps(rho10v, 0xb1); + rho11v = _mm_permute_ps(rho11v, 0xb1); + + /*Negating imaginary part for computing + * the final result of dcomplex multiplication*/ + if ( bli_is_noconj( conjx_use ) ) + { + + rho6v = _mm_mul_ps(rho6v, no_conju); + rho7v = _mm_mul_ps(rho7v, no_conju); + rho8v = _mm_mul_ps(rho8v, no_conju); + rho9v = _mm_mul_ps(rho9v, no_conju); + rho10v = _mm_mul_ps(rho10v, no_conju); + rho11v = _mm_mul_ps(rho11v, no_conju); + } + else + { + rho6v = _mm_mul_ps(rho6v, conju); + rho7v = _mm_mul_ps(rho7v, conju); + rho8v = _mm_mul_ps(rho8v, conju); + rho9v = _mm_mul_ps(rho9v, conju); + rho10v = _mm_mul_ps(rho10v, conju); + rho11v = _mm_mul_ps(rho11v, conju); + } + + rho0v = _mm_add_ps(rho0v, rho6v); + rho1v = _mm_add_ps(rho1v, rho7v); + rho2v = _mm_add_ps(rho2v, rho8v); + rho3v = _mm_add_ps(rho3v, rho9v); + rho4v = _mm_add_ps(rho4v, rho10v); + rho5v = _mm_add_ps(rho5v, rho11v); + + /** + * Horizontal addition of rho elements + * for computing final dotxf result. + * ptr pointer addresses all 6 rho + * register one by one and store the + * computed result into r buffer. + */ + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t i = 0; i < 2; i++) + { + r[0].real += ptr[i].real; + r[0].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t i = 0; i < 2; i++) + { + r[1].real += ptr[i].real; + r[1].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t i = 0; i < 2; i++) + { + r[2].real += ptr[i].real; + r[2].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t i = 0; i < 2; i++) + { + r[3].real += ptr[i].real; + r[3].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho4v; + for(dim_t i = 0; i < 2; i++) + { + r[4].real += ptr[i].real; + r[4].imag += ptr[i].imag; + } + ptr = (scomplex *)&rho5v; + for(dim_t i = 0; i < 2; i++) + { + r[5].real += ptr[i].real; + r[5].imag += ptr[i].imag; + } + } + } + /*handles remainder cases*/ + if(rem) + { + if ( bli_is_noconj( conjx_use ) ) + { + + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(c,axpys)( a[i + p*lda], x[i], r[p] ); + } + } + else + { + PRAGMA_SIMD + for(dim_t p = 0; p < 6 ; p++) + { + PASTEMAC(c,axpyjs)( a[i + p*lda], x[i], r[p] ); + } + + } + } + + if ( bli_is_conj( conjat ) ) + { + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,conjs)( r[i] ); + } + } + + /*scaling dot product result with alpha and + * adding the result to vector + */ + for ( dim_t i = 0; i < 6; ++i ) + { + PASTEMAC(c,axpys)( *alpha, r[i], y[i] ); + } + } + else + { + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(c,type); + PASTECH(c,dotxv_ker_ft) kfp_dv + = + bli_cntx_get_l1v_ker_dt( dt, BLIS_DOTXV_KER, cntx ); + + for ( dim_t i = 0; i < b_n; ++i ) + { + scomplex* restrict a1 = a + (0 )*inca + (i )*lda; + scomplex* restrict x1 = x + (0 )*incx; + scomplex* restrict psi1 = y + (i )*incy; + + kfp_dv + ( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx + ); + } + } +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 7edc0a9a1a..537e67038a 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -121,7 +121,8 @@ DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_8 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_4 ) DOTXF_KER_PROT( double, d, dotxf_zen_int_2 ) - +DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_6 ) +DOTXF_KER_PROT( scomplex, c, dotxf_zen_int_6 ) // dotxaxpyf (intrinsics) DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) From 505ff8613d7e5512a1f3cbf6ac8f4970577e02b1 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 20 Dec 2021 09:43:13 +0530 Subject: [PATCH 19/63] Removed Arch specific code from BLIS framework. - Removed BLIS_CONFIG_EPYC macro - The code dependent on this macro is handled in one of the three ways -- It is updated to work across platforms. -- Added in architecture/feature specific runtime checks. -- Duplicated in AMD specific files. Build system is updated to pick AMD specific files when library is built for any of the zen architecture AMD-Internal: [CPUPL-1960] Change-Id: I6f9f8018e41fa48eb43ae4245c9c2c361857f43b --- Makefile | 24 +- build/config.mk.in | 4 +- config/amdzen/make_defs.mk | 12 +- config/zen/make_defs.mk | 19 +- config/zen2/make_defs.mk | 16 +- config/zen3/make_defs.mk | 16 +- configure | 3 +- frame/2/gemv/bli_gemv_unf_var1.c | 350 +----- frame/2/gemv/bli_gemv_unf_var1_amd.c | 440 ++++++++ frame/2/gemv/bli_gemv_unf_var2.c | 750 +------------ frame/2/gemv/bli_gemv_unf_var2_amd.c | 879 +++++++++++++++ frame/2/hemv/bli_hemv_unf_var1.c | 204 +--- frame/2/hemv/bli_hemv_unf_var1_amd.c | 418 +++++++ frame/2/hemv/bli_hemv_unf_var3.c | 208 +--- frame/2/hemv/bli_hemv_unf_var3_amd.c | 420 +++++++ frame/2/her2/bli_her2_unf_var1.c | 212 ---- frame/2/her2/bli_her2_unf_var1_amd.c | 369 ++++++ frame/2/her2/bli_her2_unf_var4.c | 187 ---- frame/2/her2/bli_her2_unf_var4_amd.c | 354 ++++++ frame/2/trsv/bli_trsv_unf_var1.c | 411 +------ frame/2/trsv/bli_trsv_unf_var1_amd.c | 638 +++++++++++ frame/2/trsv/bli_trsv_unf_var2.c | 786 +------------ frame/2/trsv/bli_trsv_unf_var2_amd.c | 1024 +++++++++++++++++ frame/3/bli_l3_sup_int.c | 128 +-- frame/3/bli_l3_sup_int_amd.c | 352 ++++++ frame/3/gemm/bli_gemm_front.c | 15 +- frame/3/gemm/bli_gemm_front_amd.c | 413 +++++++ frame/base/bli_cpuid.c | 19 + frame/base/bli_cpuid.h | 4 +- frame/compat/bla_amax.c | 208 +--- frame/compat/bla_amax_amd.c | 295 +++++ frame/compat/bla_axpy.c | 395 +------ frame/compat/bla_axpy_amd.c | 462 ++++++++ frame/compat/bla_copy.c | 208 +--- frame/compat/bla_copy_amd.c | 285 +++++ frame/compat/bla_dot.c | 660 +---------- frame/compat/bla_dot_amd.c | 841 ++++++++++++++ frame/compat/bla_gemm.c | 502 --------- frame/compat/bla_gemm_amd.c | 894 +++++++++++++++ frame/compat/bla_gemv.c | 841 +------------- frame/compat/bla_gemv_amd.c | 963 ++++++++++++++++ frame/compat/bla_scal.c | 168 +-- frame/compat/bla_scal_amd.c | 260 +++++ frame/compat/bla_swap.c | 187 +--- frame/compat/bla_swap_amd.c | 268 +++++ frame/compat/bla_trsm.c | 1164 +------------------ frame/compat/bla_trsm_amd.c | 1544 ++++++++++++++++++++++++++ kernels/zen/1/bli_scalv_zen_int10.c | 28 +- kernels/zen/1f/bli_axpyf_zen_int_4.c | 49 +- kernels/zen/1f/bli_axpyf_zen_int_5.c | 173 +-- kernels/zen/1f/bli_axpyf_zen_int_6.c | 26 +- kernels/zen/3/bli_gemm_small.c | 12 +- 52 files changed, 11222 insertions(+), 7886 deletions(-) create mode 100644 frame/2/gemv/bli_gemv_unf_var1_amd.c create mode 100644 frame/2/gemv/bli_gemv_unf_var2_amd.c create mode 100644 frame/2/hemv/bli_hemv_unf_var1_amd.c create mode 100644 frame/2/hemv/bli_hemv_unf_var3_amd.c create mode 100644 frame/2/her2/bli_her2_unf_var1_amd.c create mode 100644 frame/2/her2/bli_her2_unf_var4_amd.c create mode 100644 frame/2/trsv/bli_trsv_unf_var1_amd.c create mode 100644 frame/2/trsv/bli_trsv_unf_var2_amd.c create mode 100644 frame/3/bli_l3_sup_int_amd.c create mode 100644 frame/3/gemm/bli_gemm_front_amd.c create mode 100644 frame/compat/bla_amax_amd.c create mode 100644 frame/compat/bla_axpy_amd.c create mode 100644 frame/compat/bla_copy_amd.c create mode 100644 frame/compat/bla_dot_amd.c create mode 100644 frame/compat/bla_gemm_amd.c create mode 100644 frame/compat/bla_gemv_amd.c create mode 100644 frame/compat/bla_scal_amd.c create mode 100644 frame/compat/bla_swap_amd.c create mode 100644 frame/compat/bla_trsm_amd.c diff --git a/Makefile b/Makefile index b248d5781a..1658e16de2 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -212,6 +212,27 @@ MK_REFKERN_OBJS := $(foreach arch, $(CONFIG_LIST), \ # Generate object file paths for all of the portable framework source code. MK_FRAME_OBJS := $(call gen-obj-paths-from-src,$(FRAME_SRC_SUFS),$(MK_FRAME_SRC),$(FRAME_PATH),$(BASE_OBJ_FRAME_PATH)) +# AMD has optimized some of the framework files, these optimizations +# may not be compatible with other platforms. +# +# In order to keep main framework code independent of AMD changes, +# AMD has duplicated the files and updated them for example +# frame/compact/bla_gemm.c : generic framework file +# frame/compact/bla_gemm_amd.c : AMD optimized framework file +# Based on the archiecture we choose correct files + +ifeq ($(MK_IS_ARCH_ZEN),yes) +# Build is being done for AMD platforms, remove the objects which +# don't have amd suffix (for which exists AMD specific implementation). +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +FILES_TO_REMOVE := $(subst _amd.o,.o, $(MK_FRAME_AMD_OBJS)) +MK_FRAME_OBJS := $(filter-out $(FILES_TO_REMOVE), $(MK_FRAME_OBJS)) +else +# Build is done for non AMD platforms, remove the amd specific objects +MK_FRAME_AMD_OBJS := $(filter $(BASE_OBJ_FRAME_PATH)/%amd.o, $(MK_FRAME_OBJS)) +MK_FRAME_OBJS := $(filter-out $(MK_FRAME_AMD_OBJS), $(MK_FRAME_OBJS)) +endif + # Generate object file paths for all of the debgu and trace logger. MK_AOCLDTL_OBJS := $(call gen-obj-paths-from-src,$(AOCLDTL_SRC_SUFS),$(MK_AOCLDTL_SRC),$(AOCLDTL_PATH),$(BASE_OBJ_AOCLDTL_PATH)) @@ -1338,4 +1359,3 @@ else @echo "Uninstalling $(@F) from $(@D)/" @- $(RM_F) $@ endif - diff --git a/build/config.mk.in b/build/config.mk.in index 709e0f543c..a880074e8f 100644 --- a/build/config.mk.in +++ b/build/config.mk.in @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -204,5 +204,7 @@ MK_ENABLE_AOCL_DYNAMIC := @enable_aocl_dynamic@ # BLAS int size MK_BLAS_INT_TYPE_SIZE := @blas_int_type_size@ +MK_IS_ARCH_ZEN := @enable_aocl_zen@ + # end of ifndef CONFIG_MK_INCLUDED conditional block endif diff --git a/config/amdzen/make_defs.mk b/config/amdzen/make_defs.mk index 7697e9ff05..e467461601 100644 --- a/config/amdzen/make_defs.mk +++ b/config/amdzen/make_defs.mk @@ -4,7 +4,7 @@ # An object-based framework for developing high-performance BLAS-like # libraries. # -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -49,16 +49,6 @@ else COPTFLAGS := -O3 endif -# This will add BLIS_CONFIG_EPYC for all framework files -# FIXME: framework files should not have architecture specific -# checks at least at compile time. Once the macro -# is defined it is applicable to every build in the -# Family including any non AMD configuration. -# However, it is still better to define it in makefiles -# instead of headers so we can have slighly more -# control on this. -COPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index be1086a1de..08d8628bec 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -46,25 +46,12 @@ AMD_CONFIG_FILE := amd_config.mk AMD_CONFIG_PATH := $(BASE_SHARE_PATH)/config/zen -include $(AMD_CONFIG_PATH)/$(AMD_CONFIG_FILE) - -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC - - ifeq ($(DEBUG_TYPE),noopt) COPTFLAGS := -O0 else COPTFLAGS := -O3 endif - # # --- Enable ETRACE across the library if enabled ETRACE_ENABLE=[0,1] ----------------------- # @@ -86,10 +73,6 @@ else CRVECFLAGS := $(CKVECFLAGS) endif -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen2/make_defs.mk b/config/zen2/make_defs.mk index ba91f722ab..3b87d35b00 100644 --- a/config/zen2/make_defs.mk +++ b/config/zen2/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,15 +50,7 @@ THIS_CONFIG := zen2 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -111,10 +103,6 @@ endif CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/config/zen3/make_defs.mk b/config/zen3/make_defs.mk index a479acf8a5..8522a1e956 100644 --- a/config/zen3/make_defs.mk +++ b/config/zen3/make_defs.mk @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -50,15 +50,7 @@ THIS_CONFIG := zen3 # general-purpose/configuration-agnostic flags in common.mk. You # may specify additional flags here as needed. -# Since we removed BLIS_CONFIG_EPYC from header file, we need to -# add it here at two places, -# CPPROCFLAGS = This will enable it for framework code -# This flag is used when configure is invoked with specific architecture -# CKOPTFLAGS = This will enable it for architecture specific kernels -# This flag is used for kernels assocaited with this architecture -# irrespective of the configuration it is built for. - -CPPROCFLAGS := -DBLIS_CONFIG_EPYC +CPPROCFLAGS := CMISCFLAGS := CPICFLAGS := CWARNFLAGS := @@ -132,10 +124,6 @@ endif # gcc CROPTFLAGS := $(CKOPTFLAGS) CRVECFLAGS := $(CKVECFLAGS) -# Add this after updating variables for reference kernels -# we don't want this defined for them -CKOPTFLAGS += -DBLIS_CONFIG_EPYC - # Store all of the variables here to new variables containing the # configuration name. $(eval $(call store-make-defs,$(THIS_CONFIG))) diff --git a/configure b/configure index bec498d3cf..f49ea19e5e 100755 --- a/configure +++ b/configure @@ -5,7 +5,7 @@ # libraries. # # Copyright (C) 2014, The University of Texas at Austin -# Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -3370,6 +3370,7 @@ main() | sed -e "s/@enable_aocl_dynamic@/${enable_aocl_dynamic}/g" \ | sed -e "s/@complex_return@/${complex_return}/g" \ | sed -e "s/@blas_int_type_size@/${blas_int_type_size}/g" \ + | sed -e "s/\@enable_aocl_zen\@/${enable_aocl_zen}/g" \ > "${config_mk_out_path}" diff --git a/frame/2/gemv/bli_gemv_unf_var1.c b/frame/2/gemv/bli_gemv_unf_var1.c index 085fe87c45..8162613c18 100644 --- a/frame/2/gemv/bli_gemv_unf_var1.c +++ b/frame/2/gemv/bli_gemv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -104,351 +104,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dgemv_unf_var1 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - - double *A1; - double *y1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - //memory pool declarations for packing vector X. - mem_t mem_bufX; - rntm_t rntm; - double *x_buf = x; - inc_t buf_incx = incx; - - bli_init_once(); - - if (cntx == NULL) - cntx = bli_gks_query_cntx(); - - bli_set_dims_incs_with_trans(transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at); - - conja = bli_extract_conj(transa); - - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(d,type); - double* x1; - double* y1; - PASTECH(d,dotxf_ker_ft) kfp_df; - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - kfp_df - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - if (incx > 1) - { - /* - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL . - */ - - mem_bufX.pblk.buf = NULL; - mem_bufX.pblk.block_size = 0; - mem_bufX.buf_type = 0; - mem_bufX.size = 0; - mem_bufX.pool = NULL; - - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ - - bli_rntm_init_from_global(&rntm); - bli_rntm_set_num_threads_only(1, &rntm); - bli_membrk_rntm_set_membrk(&rntm); - - //calculate the size required for n_elem double elements in vector X. - size_t buffer_size = n_elem * sizeof(double); - -#ifdef BLIS_ENABLE_MEM_TRACING - printf("bli_dgemv_unf_var1(): get mem pool block\n"); -#endif - - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufX.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufX); - - /*Continue packing X if buffer memory is allocated*/ - if ((bli_mem_is_alloc(&mem_bufX))) - { - x_buf = bli_mem_buffer(&mem_bufX); - - //pack X vector with non-unit stride to a temp buffer x_buf with unit stride - for (dim_t x_index = 0; x_index < n_elem; x_index++) - { - *(x_buf + x_index) = *(x + (x_index * incx)); - } - // stride of vector x_buf =1 - buf_incx = 1; - } - } - - dim_t fuse_factor = 8; - dim_t f_temp =0; - - if (n < 4) - { - fuse_factor = 2; - } else if (n < 8) - { - fuse_factor = 4; - } - - - for (i = 0; i < n_iter; i += f) - { - f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); - - //A = a + i * row_increment + 0 * column_increment - A1 = a + (i)*rs_at; - y1 = y + (i)*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - switch (f) - { - case 8: - - bli_ddotxf_zen_int_8( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx); - - break; - default: - - if (f < 4) - { - bli_ddotxf_zen_int_2( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx); - } - else - { - bli_ddotxf_zen_int_4( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x_buf, buf_incx, - beta, - y1, incy, - cntx); - } - } - - f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); - - if (f_temp < fuse_factor) - { - switch (fuse_factor) - { - case 8: - fuse_factor = 4; - break; - case 4: - fuse_factor = 2; - break; - } - } - } - - if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) - { -#ifdef BLIS_ENABLE_MEM_TRACING - printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); -#endif - // Return the buffer to pool - bli_membrk_release(&rntm, &mem_bufX); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_sgemv_unf_var1 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - float* beta, - float* y, inc_t incy, - cntx_t* cntx - ) -{ - - float* A1; - float* x1; - float* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_init_once(); - - if( cntx == NULL ) cntx = bli_gks_query_cntx(); - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_iter, &n_elem, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(s,type); - float* x1 ; - PASTECH(s,dotxf_ker_ft) kfp_df; - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - kfp_df - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 8; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_sdotxf_zen_int_8 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, cs_at, rs_at, - x1, incx, - beta, - y1, incy, - cntx - ); - - } -} - -INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) -#else INSERT_GENTFUNC_BASIC0( gemv_unf_var1 ) -#endif + diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c new file mode 100644 index 0000000000..7228c12f75 --- /dev/null +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -0,0 +1,440 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_iter, &n_elem, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,dotxf_ker_ft) kfp_df; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (i )*rs_at + (0 )*cs_at; \ + x1 = x + (0 )*incy; \ + y1 = y + (i )*incy; \ +\ + /* y1 = beta * y1 + alpha * A1 * x; */ \ + kfp_df \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, cs_at, rs_at, \ + x1, incx, \ + beta, \ + y1, incy, \ + cntx \ + ); \ +\ + } \ +} + +void bli_dgemv_unf_var1 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + double *A1; + double *y1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + //memory pool declarations for packing vector X. + mem_t mem_bufX; + rntm_t rntm; + double *x_buf = x; + inc_t buf_incx = incx; + + bli_init_once(); + + if (cntx == NULL) + cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans(transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at); + + conja = bli_extract_conj(transa); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + PASTECH(d,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + if (incx > 1) + { + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + + mem_bufX.pblk.buf = NULL; + mem_bufX.pblk.block_size = 0; + mem_bufX.buf_type = 0; + mem_bufX.size = 0; + mem_bufX.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + //calculate the size required for n_elem double elements in vector X. + size_t buffer_size = n_elem * sizeof(double); + +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): get mem pool block\n"); +#endif + + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufX.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufX); + + /*Continue packing X if buffer memory is allocated*/ + if ((bli_mem_is_alloc(&mem_bufX))) + { + x_buf = bli_mem_buffer(&mem_bufX); + + //pack X vector with non-unit stride to a temp buffer x_buf with unit stride + for (dim_t x_index = 0; x_index < n_elem; x_index++) + { + *(x_buf + x_index) = *(x + (x_index * incx)); + } + // stride of vector x_buf =1 + buf_incx = 1; + } + } + + dim_t fuse_factor = 8; + dim_t f_temp =0; + + if (n < 4) + { + fuse_factor = 2; + } else if (n < 8) + { + fuse_factor = 4; + } + + for (i = 0; i < n_iter; i += f) + { + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); + + //A = a + i * row_increment + 0 * column_increment + A1 = a + (i)*rs_at; + y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + switch (f) + { + case 8: + + bli_ddotxf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + + break; + default: + + if (f < 4) + { + bli_ddotxf_zen_int_2( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + else + { + bli_ddotxf_zen_int_4( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x_buf, buf_incx, + beta, + y1, incy, + cntx); + } + } + + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + if (f_temp < fuse_factor) + { + switch (fuse_factor) + { + case 8: + fuse_factor = 4; + break; + case 4: + fuse_factor = 2; + break; + } + } + } + + if ((incx > 1) && bli_mem_is_alloc(&mem_bufX)) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf("bli_dgemv_unf_var1(): releasing mem pool block\n"); +#endif + // Return the buffer to pool + bli_membrk_release(&rntm, &mem_bufX); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_sgemv_unf_var1 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + float* A1; + float* x1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + bli_init_once(); + + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_iter, &n_elem, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + const num_t dt = PASTEMAC(s,type); + float* x1 ; + PASTECH(s,dotxf_ker_ft) kfp_df; + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + kfp_df + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* Query the context for the kernel function pointer and fusing factor. */ + b_fuse = 8; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + + } +} + +INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) + diff --git a/frame/2/gemv/bli_gemv_unf_var2.c b/frame/2/gemv/bli_gemv_unf_var2.c index 84a67c3189..d6c21de6df 100644 --- a/frame/2/gemv/bli_gemv_unf_var2.c +++ b/frame/2/gemv/bli_gemv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -137,752 +137,4 @@ void PASTEMAC(ch,varname) \ AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_dgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - double* A1; - double* x1; - dim_t i; - dim_t f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - //memory pool declarations for packing vector Y. - mem_t mem_bufY; - rntm_t rntm; - double *y_buf = y; - inc_t buf_incy = incy; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(d,type); - double* x1; - double* y1; - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - double* zero = PASTEMAC(d,0); - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - NULL - ); - - if( bli_deq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - if (incy > 1) - { - /* - Initialize mem pool buffer to NULL and size to 0 - "buf" and "size" fields are assigned once memory - is allocated from the pool in bli_membrk_acquire_m(). - This will ensure bli_mem_is_alloc() will be passed on - an allocated memory if created or a NULL . - */ - mem_bufY.pblk.buf = NULL; mem_bufY.pblk.block_size = 0; - mem_bufY.buf_type = 0; mem_bufY.size = 0; - mem_bufY.pool = NULL; - - /* In order to get the buffer from pool via rntm access to memory broker - is needed.Following are initializations for rntm */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - //calculate the size required for n_elem double elements in vector Y. - size_t buffer_size = n_elem * sizeof(double); - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var2(): get mem pool block\n" ); - #endif - - /*acquire a Buffer(n_elem*size(double)) from the memory broker - and save the associated mem_t entry to mem_bufY.*/ - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BUFFER_FOR_B_PANEL, - &mem_bufY); - - /*Continue packing Y if buffer memory is allocated*/ - if ((bli_mem_is_alloc( &mem_bufY ))) - { - y_buf = bli_mem_buffer(&mem_bufY); - - //pack Y vector with non-unit stride to a temp buffer y_buf with unit stride - for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) - { - *(y_buf + y_index) = *(y + (y_index * incy)) ; - } - // stride of vector y_buf =1 - buf_incy = 1; - } - } - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - - /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_16x4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y_buf, buf_incy, - NULL - ); - } - if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) - { - //store the result from unit strided y_buf to non-unit strided Y - for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) - { - *(y + (y_index * incy)) = *(y_buf + y_index) ; - } - - #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemv_unf_var2(): releasing mem pool block\n" ); - #endif - // Return the buffer to pool - bli_membrk_release(&rntm , &mem_bufY); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_sgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - float* beta, - float* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - float* A1; - float* x1; - float* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(s,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(s,eq0)( *beta ) ) - { - float* zero = PASTEMAC(s,0); - /* y = 0; */ - PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(s,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - NULL - ); - - if( bli_seq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - /* Query the context for the kernel function pointer and fusing factor. */ - b_fuse = 6; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_saxpyf_zen_int_6 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - - -void bli_zgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - dcomplex* alpha, - dcomplex* a, inc_t rs_a, inc_t cs_a, - dcomplex* x, inc_t incx, - dcomplex* beta, - dcomplex* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - dcomplex* A1; - dcomplex* x1; - dcomplex* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - - /* beta=0 case is hadled by scalv internally */ - /* bli_zscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, - incy, - cntx - );*/ - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(z,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(z,eq0)( *beta ) ) - { - dcomplex* zero = PASTEMAC(z,0); - /* y = 0; */ - PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(z,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - bli_zscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - - if( bli_zeq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - // for non-unit incx, incy and rs_at and conjugate will be added in the next patch - if( (incx == 1 && incy == 1 && rs_at == 1 ) && - !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) - { - // This gemv code deals with the followint conditions only - // 1. incx, incy, and row stride equal to one - // 2. Non conjugate A matrix and X vector - // 3. No Transpose for A Martix - // Rest is taken care by the else part (axpyf implementation) - bli_zgemv_zen_int_4x4 - ( - conja, - conjx, - m, - n, - alpha, - a, rs_at, cs_at, - x, incx, - beta, - y, incy, - NULL - ); - } - else - { - /* fusing factor */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_zaxpyf_zen_int_4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - -void bli_cgemv_unf_var2 - ( - trans_t transa, - conj_t conjx, - dim_t m, - dim_t n, - scomplex* alpha, - scomplex* a, inc_t rs_a, inc_t cs_a, - scomplex* x, inc_t incx, - scomplex* beta, - scomplex* y, inc_t incy, - cntx_t* cntx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); - scomplex* A1; - scomplex* x1; - scomplex* y1; - dim_t i; - dim_t b_fuse, f; - dim_t n_elem, n_iter; - inc_t rs_at, cs_at; - conj_t conja; - - bli_set_dims_incs_with_trans( transa, - m, n, rs_a, cs_a, - &n_elem, &n_iter, &rs_at, &cs_at ); - - conja = bli_extract_conj( transa ); - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - /* y = beta * y; */ - /* beta=0 case is hadled by scalv internally */ - /*bli_cscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, - incy, - cntx - );*/ - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - const num_t dt = PASTEMAC(c,type); - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(c,eq0)( *beta ) ) - { - scomplex* zero = PASTEMAC(c,0); - /* y = 0; */ - PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(c,axpyf_ker_ft) kfp_af; - - /* Query the context for the kernel function pointer and fusing factor. */ - kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - kfp_af - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - cntx - ); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - - bli_cscalv_ex - ( - BLIS_NO_CONJUGATE, - n_elem, - beta, - y, incy, - cntx, - NULL - ); - - - - if( bli_ceq0( *alpha ) ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) - return; - } - - // for non-unit incx, incy and rs_at and conjugate will be added in the next patch - if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && - !bli_is_conj(conja) && !bli_is_conj(conjx) && - !bli_is_trans(transa)) - { - // This gemv code deals with the followint conditions only - // 1. incx, incy, and row stride equal to one - // 2. Non conjugate A matrix and X vector - // 3. No Transpose for A Martix - // Rest is taken care by the else part (axpyf implementation) - bli_cgemv_zen_int_4x4 - ( - conja, - conjx, - m, - n, - alpha, - a, rs_at, cs_at, - x, incx, - beta, - y, incy, - NULL - ); - } - else - { - /* fusing factor. */ - b_fuse = 4; - - for ( i = 0; i < n_iter; i += f ) - { - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - y1 = y + (0 )*incy; - - /* y = y + alpha * A1 * x1; */ - bli_caxpyf_zen_int_4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y1, incy, - NULL - ); - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); -} - - -#else INSERT_GENTFUNC_BASIC0( gemv_unf_var2 ) -#endif diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c new file mode 100644 index 0000000000..d7f5145e31 --- /dev/null +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -0,0 +1,879 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#define BLIS_DGEMV_VAR2_FUSE 4 + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + trans_t transa, \ + conj_t conjx, \ + dim_t m, \ + dim_t n, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); \ +\ + bli_init_once(); \ +\ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A1; \ + ctype* x1; \ + ctype* y1; \ + dim_t i; \ + dim_t b_fuse, f; \ + dim_t n_elem, n_iter; \ + inc_t rs_at, cs_at; \ + conj_t conja; \ +\ + bli_set_dims_incs_with_trans( transa, \ + m, n, rs_a, cs_a, \ + &n_elem, &n_iter, &rs_at, &cs_at ); \ +\ + conja = bli_extract_conj( transa ); \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n_elem, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + for ( i = 0; i < n_iter; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); \ +\ + A1 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + y1 = y + (0 )*incy; \ +\ + /* y = y + alpha * A1 * x1; */ \ + kfp_af \ + ( \ + conja, \ + conjx, \ + n_elem, \ + f, \ + alpha, \ + A1, rs_at, cs_at, \ + x1, incx, \ + y1, incy, \ + cntx \ + ); \ + } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); \ +} + +void bli_dgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + double* A1; + double* x1; + dim_t i; + dim_t f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + //memory pool declarations for packing vector Y. + mem_t mem_bufY; + rntm_t rntm; + double *y_buf = y; + inc_t buf_incy = incy; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(d,type); + double* x1; + double* y1; + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + double* zero = PASTEMAC(d,0); + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + dim_t b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx + ); + + if( bli_deq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + if (incy > 1) + { + /* + Initialize mem pool buffer to NULL and size to 0 + "buf" and "size" fields are assigned once memory + is allocated from the pool in bli_membrk_acquire_m(). + This will ensure bli_mem_is_alloc() will be passed on + an allocated memory if created or a NULL . + */ + mem_bufY.pblk.buf = NULL; mem_bufY.pblk.block_size = 0; + mem_bufY.buf_type = 0; mem_bufY.size = 0; + mem_bufY.pool = NULL; + + /* In order to get the buffer from pool via rntm access to memory broker + is needed.Following are initializations for rntm */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + //calculate the size required for n_elem double elements in vector Y. + size_t buffer_size = n_elem * sizeof(double); + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemv_unf_var2(): get mem pool block\n" ); + #endif + + /*acquire a Buffer(n_elem*size(double)) from the memory broker + and save the associated mem_t entry to mem_bufY.*/ + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BUFFER_FOR_B_PANEL, + &mem_bufY); + + /*Continue packing Y if buffer memory is allocated*/ + if ((bli_mem_is_alloc( &mem_bufY ))) + { + y_buf = bli_mem_buffer(&mem_bufY); + + //pack Y vector with non-unit stride to a temp buffer y_buf with unit stride + for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) + { + *(y_buf + y_index) = *(y + (y_index * incy)) ; + } + // stride of vector y_buf =1 + buf_incy = 1; + } + } + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + + /* y = y + alpha * A1 * x1; */ + bli_daxpyf_zen_int_16x4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx + ); + } + if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) + { + //store the result from unit strided y_buf to non-unit strided Y + for(dim_t y_index = 0 ; y_index < n_elem ; y_index++) + { + *(y + (y_index * incy)) = *(y_buf + y_index) ; + } + + #ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemv_unf_var2(): releasing mem pool block\n" ); + #endif + // Return the buffer to pool + bli_membrk_release(&rntm , &mem_bufY); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_sgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + float* beta, + float* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + float* A1; + float* x1; + float* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(s,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(s,eq0)( *beta ) ) + { + float* zero = PASTEMAC(s,0); + /* y = 0; */ + PASTEMAC2(s,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(s,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx + ); + + if( bli_seq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + /* Query the context for the kernel function pointer and fusing factor. */ + b_fuse = 6; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_saxpyf_zen_int_6 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + +void bli_zgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + dcomplex* beta, + dcomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + dcomplex* A1; + dcomplex* x1; + dcomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + + /* beta=0 case is hadled by scalv internally */ + /* bli_zscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, + incy, + cntx + );*/ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(z,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(z,eq0)( *beta ) ) + { + dcomplex* zero = PASTEMAC(z,0); + /* y = 0; */ + PASTEMAC2(z,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(z,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(z,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + bli_zscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + if( bli_zeq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( (incx == 1 && incy == 1 && rs_at == 1 ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && !bli_is_trans(transa)) + { + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_zgemv_zen_int_4x4 + ( + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + cntx + ); + } + else + { + /* fusing factor */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_zaxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +void bli_cgemv_unf_var2 + ( + trans_t transa, + conj_t conjx, + dim_t m, + dim_t n, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + scomplex* beta, + scomplex* y, inc_t incy, + cntx_t* cntx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + scomplex* A1; + scomplex* x1; + scomplex* y1; + dim_t i; + dim_t b_fuse, f; + dim_t n_elem, n_iter; + inc_t rs_at, cs_at; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here. + bli_init_once(); + if(cntx == NULL) cntx = bli_gks_query_cntx(); + + bli_set_dims_incs_with_trans( transa, + m, n, rs_a, cs_a, + &n_elem, &n_iter, &rs_at, &cs_at ); + + conja = bli_extract_conj( transa ); + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + /* y = beta * y; */ + /* beta=0 case is hadled by scalv internally */ + /*bli_cscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, + incy, + cntx + );*/ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + const num_t dt = PASTEMAC(c,type); + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(c,eq0)( *beta ) ) + { + scomplex* zero = PASTEMAC(c,0); + /* y = 0; */ + PASTEMAC2(c,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(c,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(c,axpyf_ker_ft) kfp_af; + + /* Query the context for the kernel function pointer and fusing factor. */ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + kfp_af + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + + bli_cscalv_ex + ( + BLIS_NO_CONJUGATE, + n_elem, + beta, + y, incy, + cntx, + NULL + ); + + + + if( bli_ceq0( *alpha ) ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3) + return; + } + + // for non-unit incx, incy and rs_at and conjugate will be added in the next patch + if( ( (incx == 1) && (incy == 1) && (rs_at == 1) ) && + !bli_is_conj(conja) && !bli_is_conj(conjx) && + !bli_is_trans(transa)) + { + // This gemv code deals with the followint conditions only + // 1. incx, incy, and row stride equal to one + // 2. Non conjugate A matrix and X vector + // 3. No Transpose for A Martix + // Rest is taken care by the else part (axpyf implementation) + bli_cgemv_zen_int_4x4 + ( + conja, + conjx, + m, + n, + alpha, + a, rs_at, cs_at, + x, incx, + beta, + y, incy, + cntx + ); + } + else + { + /* fusing factor. */ + b_fuse = 4; + + for ( i = 0; i < n_iter; i += f ) + { + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + A1 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + y1 = y + (0 )*incy; + + /* y = y + alpha * A1 * x1; */ + bli_caxpyf_zen_int_4 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y1, incy, + cntx + ); + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + + + diff --git a/frame/2/hemv/bli_hemv_unf_var1.c b/frame/2/hemv/bli_hemv_unf_var1.c index 6790e5bd08..e3229543c0 100644 --- a/frame/2/hemv/bli_hemv_unf_var1.c +++ b/frame/2/hemv/bli_hemv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -216,207 +216,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_post_hemv_8x8 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t cs_a, - dim_t rs_a - ); - -void bli_dhemv_unf_var1 - ( - uplo_t uplo, - conj_t conja, - conj_t conjx, - conj_t conjh, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* one = PASTEMAC(d,1); - double* zero = PASTEMAC(d,0); - double* A10; - double* A11; - double* a10t; - double* alpha11; - double* a21; - double* x0; - double* x1; - double* chi11; - double* y0; - double* y1; - double* y01; - double* psi11; - double* y21; - double conjx_chi11; - double alpha_chi11; - double alpha11_temp; - dim_t i, k, j; - dim_t b_fuse, f; - dim_t n_behind; - dim_t f_ahead, f_behind; - inc_t rs_at, cs_at; - conj_t conj0 = 0, conj1 = 0; - - /* The algorithm will be expressed in terms of the lower triangular - * case;the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. */ - if ( bli_is_lower( uplo ) ) - { - rs_at = rs_a; - cs_at = cs_a; - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; - - /* Query the context for the kernel function pointer and fusing - * factor. */ - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = ((id == BLIS_ARCH_ZEN4) ||(id == BLIS_ARCH_ZEN3) - || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); - if (bamdzen) - { - kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_dotxaxpyf_ker = - bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); - b_fuse = - bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); - } - - for ( i = 0; i < m; i += f ) - { - f = bli_determine_blocksize_dim_f( i, m, b_fuse ); - n_behind = i; - A10 = a + (i )*rs_at + (0 )*cs_at; - A11 = a + (i )*rs_at + (i )*cs_at; - x0 = x + (0 )*incx; - x1 = x + (i )*incx; - y0 = y + (0 )*incy; - y1 = y + (i )*incy; - - /* y1 = y1 + alpha * A10 * x0; (dotxf) */ - /* y0 = y0 + alpha * A10' * x1; (axpyf) */ - kfp_dotxaxpyf_ker - ( - conj0, - conj1, - conjx, - conjx, - n_behind, - f, - alpha, - A10, cs_at, rs_at, - x0, incx, - x1, incx, - one, - y1, incy, - y0, incy, - cntx - ); - - /* y1 = y1 + alpha * A11 * x1; (variant 4) */ - if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) - { - /*this helper function handles unit stride only*/ - bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); - } - else - { - for ( k = 0; k < f; ++k ) - { - f_behind = k; - f_ahead = f - k - 1; - a10t = A11 + (k )*rs_at + (0 )*cs_at; - alpha11 = A11 + (k )*rs_at + (k )*cs_at; - a21 = A11 + (k+1)*rs_at + (k )*cs_at; - chi11 = x1 + (k )*incx; - y01 = y1 + (0 )*incy; - psi11 = y1 + (k )*incy; - y21 = y1 + (k+1)*incy; - - /* y01 = y01 + alpha * a10t' * chi11; */ - PASTEMAC(d,copycjs)( conjx, *chi11, - conjx_chi11 ); - PASTEMAC(d,scal2s)( *alpha, conjx_chi11, - alpha_chi11 ); - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,axpys)( alpha_chi11, - *(a10t + j*cs_at), - *(y01 + j*incy) ); - - PASTEMAC(d,copycjs)( conja, *alpha11, - alpha11_temp ); - - /* psi11 = psi11 + alpha * alpha11 * chi11; */ - PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, - *psi11 ); - - /* y21 = y21 + alpha * a21 * chi11; */ - for ( j = 0; j < f_ahead; ++j ) - { - PASTEMAC(d,axpys)( alpha_chi11, - *(a21 + j*rs_at), - *(y21 + j*incy) ); - } - } - } - } -} -GENTFUNC(float, s, hemv_unf_var1) -GENTFUNC(scomplex, c, hemv_unf_var1) -GENTFUNC(dcomplex, z, hemv_unf_var1) -#else INSERT_GENTFUNC_BASIC0( hemv_unf_var1 ) -#endif diff --git a/frame/2/hemv/bli_hemv_unf_var1_amd.c b/frame/2/hemv/bli_hemv_unf_var1_amd.c new file mode 100644 index 0000000000..6532323d11 --- /dev/null +++ b/frame/2/hemv/bli_hemv_unf_var1_amd.c @@ -0,0 +1,418 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conja, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A10; \ + ctype* A11; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* chi11; \ + ctype* y0; \ + ctype* y1; \ + ctype* y01; \ + ctype* psi11; \ + ctype* y21; \ + ctype conjx_chi11; \ + ctype alpha_chi11; \ + ctype alpha11_temp; \ + dim_t i, k, j; \ + dim_t b_fuse, f; \ + dim_t n_behind; \ + dim_t f_ahead, f_behind; \ + inc_t rs_at, cs_at; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ +\ + conj0 = conja; \ + conj1 = bli_apply_conj( conjh, conja ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ +\ + conj0 = bli_apply_conj( conjh, conja ); \ + conj1 = conja; \ + } \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,dotxaxpyf_ker_ft) kfp_xf; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_xf = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); \ +\ + for ( i = 0; i < m; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); \ + n_behind = i; \ + A10 = a + (i )*rs_at + (0 )*cs_at; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + x0 = x + (0 )*incx; \ + x1 = x + (i )*incx; \ + y0 = y + (0 )*incy; \ + y1 = y + (i )*incy; \ +\ + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ \ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ \ + kfp_xf \ + ( \ + conj0, \ + conj1, \ + conjx, \ + conjx, \ + n_behind, \ + f, \ + alpha, \ + A10, cs_at, rs_at, \ + x0, incx, \ + x1, incx, \ + one, \ + y1, incy, \ + y0, incy, \ + cntx \ + ); \ +\ + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ \ + for ( k = 0; k < f; ++k ) \ + { \ + f_behind = k; \ + f_ahead = f - k - 1; \ + a10t = A11 + (k )*rs_at + (0 )*cs_at; \ + alpha11 = A11 + (k )*rs_at + (k )*cs_at; \ + a21 = A11 + (k+1)*rs_at + (k )*cs_at; \ + chi11 = x1 + (k )*incx; \ + y01 = y1 + (0 )*incy; \ + psi11 = y1 + (k )*incy; \ + y21 = y1 + (k+1)*incy; \ +\ + /* y01 = y01 + alpha * a10t' * chi11; */ \ + PASTEMAC(ch,copycjs)( conjx, *chi11, conjx_chi11 ); \ + PASTEMAC(ch,scal2s)( *alpha, conjx_chi11, alpha_chi11 ); \ + if ( bli_is_conj( conj1 ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ +\ + /* For hemv, explicitly set the imaginary component of alpha11 to + zero. */ \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_temp ); \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( alpha11_temp ); \ +\ + /* psi11 = psi11 + alpha * alpha11 * chi11; */ \ + PASTEMAC(ch,axpys)( alpha_chi11, alpha11_temp, *psi11 ); \ +\ + /* y21 = y21 + alpha * a21 * chi11; */ \ + if ( bli_is_conj( conj0 ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + } \ + } \ +} + +void bli_post_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + +void bli_dhemv_unf_var1 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A10; + double* A11; + double* a10t; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* chi11; + double* y0; + double* y1; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_behind; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + /* Query the context for the kernel function pointer and fusing + * factor. */ + /* Assign kernel function pointer and fusing factor. */ + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_behind = i; + A10 = a + (i )*rs_at + (0 )*cs_at; + A11 = a + (i )*rs_at + (i )*cs_at; + x0 = x + (0 )*incx; + x1 = x + (i )*incx; + y0 = y + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = y1 + alpha * A10 * x0; (dotxf) */ + /* y0 = y0 + alpha * A10' * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_behind, + f, + alpha, + A10, cs_at, rs_at, + x0, incx, + x1, incx, + one, + y1, incy, + y0, incy, + cntx + ); + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (cs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_post_hemv_8x8(A11, x1, y1, alpha, rs_at, cs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, *chi11, + conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,axpys)( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + } +} +GENTFUNC(float, s, hemv_unf_var1) +GENTFUNC(scomplex, c, hemv_unf_var1) +GENTFUNC(dcomplex, z, hemv_unf_var1) + + diff --git a/frame/2/hemv/bli_hemv_unf_var3.c b/frame/2/hemv/bli_hemv_unf_var3.c index abf08dfdaf..b8e26cbcb6 100644 --- a/frame/2/hemv/bli_hemv_unf_var3.c +++ b/frame/2/hemv/bli_hemv_unf_var3.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -216,210 +216,6 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC - -void bli_pre_hemv_8x8 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t cs_a, - dim_t rs_a - ); - -void bli_dhemv_unf_var3 - ( - uplo_t uplo, - conj_t conja, - conj_t conjx, - conj_t conjh, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - double* beta, - double* y, inc_t incy, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* one = PASTEMAC(d,1); - double* zero = PASTEMAC(d,0); - double* A11; - double* A21; - double* a10t; - double* alpha11; - double* a21; - double* x1; - double* x2; - double* chi11; - double* y1; - double* y2; - double* y01; - double* psi11; - double* y21; - double conjx_chi11; - double alpha_chi11; - double alpha11_temp; - dim_t i, k, j; - dim_t b_fuse, f; - dim_t n_ahead; - dim_t f_ahead, f_behind; - inc_t rs_at, cs_at; - conj_t conj0 = 0, conj1 = 0; - - /* The algorithm will be expressed in terms of the lower triangular - * case; the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. */ - if ( bli_is_lower( uplo ) ) - { - rs_at = rs_a; - cs_at = cs_a; - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - } - - /* If beta is zero, use setv. Otherwise, scale by beta. */ - if ( PASTEMAC(d,eq0)( *beta ) ) - { - /* y = 0; */ - PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - zero, - y, incy, - cntx, - NULL - ); - } - else - { - /* y = beta * y; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - beta, - y, incy, - cntx, - NULL - ); - } - - PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; - - arch_t id = bli_arch_query_id(); - bool bamdzen = ((id == BLIS_ARCH_ZEN4) || (id == BLIS_ARCH_ZEN3) - || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN)); - if (bamdzen) - { - kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_dotxaxpyf_ker = - bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); - b_fuse = - bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); - } - - for ( i = 0; i < m; i += f ) - { - f = bli_determine_blocksize_dim_f( i, m, b_fuse ); - n_ahead = m - i - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - y1 = y + (i )*incy; - y2 = y + (i+f)*incy; - - /* y1 = y1 + alpha * A11 * x1; (variant 4) */ - if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) - { - /*this helper function handles unit stride only*/ - bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); - } - else - { - for ( k = 0; k < f; ++k ) - { - f_behind = k; - f_ahead = f - k - 1; - a10t = A11 + (k )*rs_at + (0 )*cs_at; - alpha11 = A11 + (k )*rs_at + (k )*cs_at; - a21 = A11 + (k+1)*rs_at + (k )*cs_at; - chi11 = x1 + (k )*incx; - y01 = y1 + (0 )*incy; - psi11 = y1 + (k )*incy; - y21 = y1 + (k+1)*incy; - - /* y01 = y01 + alpha * a10t' * chi11; */ - PASTEMAC(d,copycjs)( conjx, - *chi11, conjx_chi11 ); - PASTEMAC(d,scal2s)( *alpha, conjx_chi11, - alpha_chi11 ); - { - for ( j = 0; j < f_behind; ++j ) - { - PASTEMAC(d,axpys) - ( alpha_chi11, - *(a10t + j*cs_at), - *(y01 + j*incy) ); - } - } - - PASTEMAC(d,copycjs)( conja, *alpha11, - alpha11_temp ); - - /* psi11 = psi11 + alpha * alpha11 * chi11; */ - PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, - *psi11 ); - - /* y21 = y21 + alpha * a21 * chi11; */ - for ( j = 0; j < f_ahead; ++j ) - { - PASTEMAC(d,axpys)( alpha_chi11, - *(a21 + j*rs_at), - *(y21 + j*incy) ); - } - } - } - - /* y1 = y1 + alpha * A21' * x2; (dotxf) */ - /* y2 = y2 + alpha * A21 * x1; (axpyf) */ - kfp_dotxaxpyf_ker - ( - conj0, - conj1, - conjx, - conjx, - n_ahead, - f, - alpha, - A21, rs_at, cs_at, - x2, incx, - x1, incx, - one, - y1, incy, - y2, incy, - cntx - ); - } -} - -GENTFUNC(float, s, hemv_unf_var3) -GENTFUNC(scomplex, c, hemv_unf_var3) -GENTFUNC(dcomplex, z, hemv_unf_var3) -#else INSERT_GENTFUNC_BASIC0( hemv_unf_var3 ) -#endif + diff --git a/frame/2/hemv/bli_hemv_unf_var3_amd.c b/frame/2/hemv/bli_hemv_unf_var3_amd.c new file mode 100644 index 0000000000..34d40cf5cc --- /dev/null +++ b/frame/2/hemv/bli_hemv_unf_var3_amd.c @@ -0,0 +1,420 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conja, \ + conj_t conjx, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + ctype* beta, \ + ctype* y, inc_t incy, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* zero = PASTEMAC(ch,0); \ + ctype* A11; \ + ctype* A21; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x1; \ + ctype* x2; \ + ctype* chi11; \ + ctype* y1; \ + ctype* y2; \ + ctype* y01; \ + ctype* psi11; \ + ctype* y21; \ + ctype conjx_chi11; \ + ctype alpha_chi11; \ + ctype alpha11_temp; \ + dim_t i, k, j; \ + dim_t b_fuse, f; \ + dim_t n_ahead; \ + dim_t f_ahead, f_behind; \ + inc_t rs_at, cs_at; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ +\ + conj0 = bli_apply_conj( conjh, conja ); \ + conj1 = conja; \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ +\ + conj0 = conja; \ + conj1 = bli_apply_conj( conjh, conja ); \ + } \ +\ + /* If beta is zero, use setv. Otherwise, scale by beta. */ \ + if ( PASTEMAC(ch,eq0)( *beta ) ) \ + { \ + /* y = 0; */ \ + PASTEMAC2(ch,setv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + zero, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ + else \ + { \ + /* y = beta * y; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + beta, \ + y, incy, \ + cntx, \ + NULL \ + ); \ + } \ +\ + PASTECH(ch,dotxaxpyf_ker_ft) kfp_xf; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_xf = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); \ +\ + for ( i = 0; i < m; i += f ) \ + { \ + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); \ + n_ahead = m - i - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ + y1 = y + (i )*incy; \ + y2 = y + (i+f)*incy; \ +\ + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ \ + for ( k = 0; k < f; ++k ) \ + { \ + f_behind = k; \ + f_ahead = f - k - 1; \ + a10t = A11 + (k )*rs_at + (0 )*cs_at; \ + alpha11 = A11 + (k )*rs_at + (k )*cs_at; \ + a21 = A11 + (k+1)*rs_at + (k )*cs_at; \ + chi11 = x1 + (k )*incx; \ + y01 = y1 + (0 )*incy; \ + psi11 = y1 + (k )*incy; \ + y21 = y1 + (k+1)*incy; \ +\ + /* y01 = y01 + alpha * a10t' * chi11; */ \ + PASTEMAC(ch,copycjs)( conjx, *chi11, conjx_chi11 ); \ + PASTEMAC(ch,scal2s)( *alpha, conjx_chi11, alpha_chi11 ); \ + if ( bli_is_conj( conj0 ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a10t + j*cs_at), *(y01 + j*incy) ); \ + } \ +\ + /* For hemv, explicitly set the imaginary component of alpha11 to + zero. */ \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_temp ); \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( alpha11_temp ); \ +\ + /* psi11 = psi11 + alpha * alpha11 * chi11; */ \ + PASTEMAC(ch,axpys)( alpha_chi11, alpha11_temp, *psi11 ); \ +\ + /* y21 = y21 + alpha * a21 * chi11; */ \ + if ( bli_is_conj( conj1 ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( alpha_chi11, *(a21 + j*rs_at), *(y21 + j*incy) ); \ + } \ + } \ +\ + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ \ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ \ + kfp_xf \ + ( \ + conj0, \ + conj1, \ + conjx, \ + conjx, \ + n_ahead, \ + f, \ + alpha, \ + A21, rs_at, cs_at, \ + x2, incx, \ + x1, incx, \ + one, \ + y1, incy, \ + y2, incy, \ + cntx \ + ); \ + } \ +} + +void bli_pre_hemv_8x8 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t cs_a, + dim_t rs_a + ); + +void bli_dhemv_unf_var3 + ( + uplo_t uplo, + conj_t conja, + conj_t conjx, + conj_t conjh, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + double* beta, + double* y, inc_t incy, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* one = PASTEMAC(d,1); + double* zero = PASTEMAC(d,0); + double* A11; + double* A21; + double* a10t; + double* alpha11; + double* a21; + double* x1; + double* x2; + double* chi11; + double* y1; + double* y2; + double* y01; + double* psi11; + double* y21; + double conjx_chi11; + double alpha_chi11; + double alpha11_temp; + dim_t i, k, j; + dim_t b_fuse, f; + dim_t n_ahead; + dim_t f_ahead, f_behind; + inc_t rs_at, cs_at; + conj_t conj0 = 0, conj1 = 0; + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. */ + if ( bli_is_lower( uplo ) ) + { + rs_at = rs_a; + cs_at = cs_a; + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + } + + /* If beta is zero, use setv. Otherwise, scale by beta. */ + if ( PASTEMAC(d,eq0)( *beta ) ) + { + /* y = 0; */ + PASTEMAC2(d,setv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + zero, + y, incy, + cntx, + NULL + ); + } + else + { + /* y = beta * y; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + beta, + y, incy, + cntx, + NULL + ); + } + + PASTECH(d,dotxaxpyf_ker_ft) kfp_dotxaxpyf_ker; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + kfp_dotxaxpyf_ker = bli_ddotxaxpyf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_dotxaxpyf_ker = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXAXPYF_KER, cntx); + b_fuse = + bli_cntx_get_blksz_def_dt( dt, BLIS_XF, cntx ); + } + + for ( i = 0; i < m; i += f ) + { + f = bli_determine_blocksize_dim_f( i, m, b_fuse ); + n_ahead = m - i - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + y1 = y + (i )*incy; + y2 = y + (i+f)*incy; + + /* y1 = y1 + alpha * A11 * x1; (variant 4) */ + if((f == 8) && (incx == 1) && (incy == 1) && (rs_at == 1)) + { + /*this helper function handles unit stride only*/ + bli_pre_hemv_8x8(A11, x1, y1, alpha, cs_at, rs_at); + } + else + { + for ( k = 0; k < f; ++k ) + { + f_behind = k; + f_ahead = f - k - 1; + a10t = A11 + (k )*rs_at + (0 )*cs_at; + alpha11 = A11 + (k )*rs_at + (k )*cs_at; + a21 = A11 + (k+1)*rs_at + (k )*cs_at; + chi11 = x1 + (k )*incx; + y01 = y1 + (0 )*incy; + psi11 = y1 + (k )*incy; + y21 = y1 + (k+1)*incy; + + /* y01 = y01 + alpha * a10t' * chi11; */ + PASTEMAC(d,copycjs)( conjx, + *chi11, conjx_chi11 ); + PASTEMAC(d,scal2s)( *alpha, conjx_chi11, + alpha_chi11 ); + { + for ( j = 0; j < f_behind; ++j ) + { + PASTEMAC(d,axpys) + ( alpha_chi11, + *(a10t + j*cs_at), + *(y01 + j*incy) ); + } + } + + PASTEMAC(d,copycjs)( conja, *alpha11, + alpha11_temp ); + + /* psi11 = psi11 + alpha * alpha11 * chi11; */ + PASTEMAC(d,axpys)( alpha_chi11, alpha11_temp, + *psi11 ); + + /* y21 = y21 + alpha * a21 * chi11; */ + for ( j = 0; j < f_ahead; ++j ) + { + PASTEMAC(d,axpys)( alpha_chi11, + *(a21 + j*rs_at), + *(y21 + j*incy) ); + } + } + } + + /* y1 = y1 + alpha * A21' * x2; (dotxf) */ + /* y2 = y2 + alpha * A21 * x1; (axpyf) */ + kfp_dotxaxpyf_ker + ( + conj0, + conj1, + conjx, + conjx, + n_ahead, + f, + alpha, + A21, rs_at, cs_at, + x2, incx, + x1, incx, + one, + y1, incy, + y2, incy, + cntx + ); + } +} + +GENTFUNC(float, s, hemv_unf_var3) +GENTFUNC(scomplex, c, hemv_unf_var3) +GENTFUNC(dcomplex, z, hemv_unf_var3) + + diff --git a/frame/2/her2/bli_her2_unf_var1.c b/frame/2/her2/bli_her2_unf_var1.c index 299e3d161d..a0aec48f71 100644 --- a/frame/2/her2/bli_her2_unf_var1.c +++ b/frame/2/her2/bli_her2_unf_var1.c @@ -158,217 +158,5 @@ void PASTEMAC(ch,varname) \ } \ } - -#ifdef BLIS_CONFIG_EPYC - -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. - * a is triangular matrix, x and y are vectors - */ -void bli_dher2_trans_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var1 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - const num_t dt = PASTEMAC(d,type); - - double* x0; - double* chi1; - double* y0; - double* psi1; - double* c10t; - double* gamma11; - double alpha0; - double alpha1; - double alpha0_chi1; - double alpha1_psi1; - double alpha0_chi1_psi1; - double conjx0_chi1; - double conjy1_psi1; - double conjy0_psi1; - dim_t i; - dim_t n_behind; - inc_t rs_ct, cs_ct; - conj_t conj0, conj1; - - /* The algorithm will be expressed in terms of the lower triangular - * case;the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. - */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - conjx = bli_apply_conj( conjh, conjx ); - conjy = bli_apply_conj( conjh, conjy ); - - PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); - PASTEMAC(d,copys)( *alpha, alpha1 ); - } - - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - conj0 = bli_apply_conj( conjh, conjy ); - conj1 = bli_apply_conj( conjh, conjx ); - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if( (incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - if((n_behind >= 3)) - { - bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); - i+=4; - } - else - { - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i ) - { - n_behind = i; - x0 = x + (0 )*incx; - chi1 = x + (i )*incx; - y0 = y + (0 )*incy; - psi1 = y + (i )*incy; - c10t = c + (i )*rs_ct + (0 )*cs_ct; - gamma11 = c + (i )*rs_ct + (i )*cs_ct; - - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, - * by conjx and conjy. - */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); - - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ - kfp_2v - ( - conj0, - conj1, - n_behind, - &alpha0_chi1, - &alpha1_psi1, - y0, incy, - x0, incx, - c10t, cs_ct, - cntx - ); - - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - - } - } -} - -GENTFUNC(float, s, her2_unf_var1) -GENTFUNC(scomplex, c, her2_unf_var1) -GENTFUNC(dcomplex, z,her2_unf_var1) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var1 ) -#endif diff --git a/frame/2/her2/bli_her2_unf_var1_amd.c b/frame/2/her2/bli_her2_unf_var1_amd.c new file mode 100644 index 0000000000..43a74f49cd --- /dev/null +++ b/frame/2/her2/bli_her2_unf_var1_amd.c @@ -0,0 +1,369 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjy, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* x0; \ + ctype* chi1; \ + ctype* y0; \ + ctype* psi1; \ + ctype* c10t; \ + ctype* gamma11; \ + ctype alpha0; \ + ctype alpha1; \ + ctype alpha0_chi1; \ + ctype alpha1_psi1; \ + ctype alpha0_chi1_psi1; \ + ctype conjx0_chi1; \ + ctype conjy1_psi1; \ + ctype conjy0_psi1; \ + dim_t i; \ + dim_t n_behind; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ +\ + PASTEMAC(ch,copys)( *alpha, alpha0 ); \ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha1 ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx/conjy, but only if we are being invoked + as her2; for syr2, conjx/conjy are unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha0 ); \ + PASTEMAC(ch,copys)( *alpha, alpha1 ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx and/or conjy as needed to arrive at + the effective conjugation for the vector subproblems. */ \ + conj0 = bli_apply_conj( conjh, conjy ); \ + conj1 = bli_apply_conj( conjh, conjx ); \ +\ + PASTECH(ch,axpy2v_ker_ft) kfp_2v; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_behind = i; \ + x0 = x + (0 )*incx; \ + chi1 = x + (i )*incx; \ + y0 = y + (0 )*incy; \ + psi1 = y + (i )*incy; \ + c10t = c + (i )*rs_ct + (0 )*cs_ct; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx and/or conjy to chi1 and/or psi1. */ \ + PASTEMAC(ch,copycjs)( conjx, *chi1, conjx0_chi1 ); \ + PASTEMAC(ch,copycjs)( conjy, *psi1, conjy1_psi1 ); \ + PASTEMAC(ch,copycjs)( conj0, *psi1, conjy0_psi1 ); \ +\ + /* Compute scalars for vector subproblems. */ \ + PASTEMAC(ch,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); \ + PASTEMAC(ch,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); \ +\ + /* Compute alpha * chi1 * conj(psi1) after both chi1 and psi1 have + already been conjugated, if needed, by conjx and conjy. */ \ + PASTEMAC(ch,scal2s)( alpha0_chi1, conjy0_psi1, alpha0_chi1_psi1 ); \ +\ + /* c10t = c10t + alpha * chi1 * y0'; */ \ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ \ + kfp_2v \ + ( \ + conj0, \ + conj1, \ + n_behind, \ + &alpha0_chi1, \ + &alpha1_psi1, \ + y0, incy, \ + x0, incx, \ + c10t, cs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) \ + + conj(alpha) * psi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_trans_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var1 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + const num_t dt = PASTEMAC(d,type); + + double* x0; + double* chi1; + double* y0; + double* psi1; + double* c10t; + double* gamma11; + double alpha0; + double alpha1; + double alpha0_chi1; + double alpha1_psi1; + double alpha0_chi1_psi1; + double conjx0_chi1; + double conjy1_psi1; + double conjy0_psi1; + dim_t i; + dim_t n_behind; + inc_t rs_ct, cs_ct; + conj_t conj0, conj1; + + /* The algorithm will be expressed in terms of the lower triangular + * case;the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + PASTEMAC(d,copycjs)( conjh, *alpha, alpha1 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + conjx = bli_apply_conj( conjh, conjx ); + conjy = bli_apply_conj( conjh, conjy ); + + PASTEMAC(d,copycjs)( conjh, *alpha, alpha0 ); + PASTEMAC(d,copys)( *alpha, alpha1 ); + } + + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + conj0 = bli_apply_conj( conjh, conjy ); + conj1 = bli_apply_conj( conjh, conjx ); + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + if((n_behind >= 3)) + { + bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + i+=4; + } + else + { + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i ) + { + n_behind = i; + x0 = x + (0 )*incx; + chi1 = x + (i )*incx; + y0 = y + (0 )*incy; + psi1 = y + (i )*incy; + c10t = c + (i )*rs_ct + (0 )*cs_ct; + gamma11 = c + (i )*rs_ct + (i )*cs_ct; + + /* Apply conjx and/or conjy to chi1 and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have already been conjugated, if needed, + * by conjx and conjy. + */ + PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, + alpha0_chi1_psi1 ); + + /* c10t = c10t + alpha * chi1 * y0'; */ + /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + kfp_2v + ( + conj0, + conj1, + n_behind, + &alpha0_chi1, + &alpha1_psi1, + y0, incy, + x0, incx, + c10t, cs_ct, + cntx + ); + + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) + + conj(alpha) * psi1 * conj(chi1); */ + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + + } + } +} + +GENTFUNC(float, s, her2_unf_var1) +GENTFUNC(scomplex, c, her2_unf_var1) +GENTFUNC(dcomplex, z,her2_unf_var1) + + diff --git a/frame/2/her2/bli_her2_unf_var4.c b/frame/2/her2/bli_her2_unf_var4.c index e39c7224c4..3dea31d53e 100644 --- a/frame/2/her2/bli_her2_unf_var4.c +++ b/frame/2/her2/bli_her2_unf_var4.c @@ -166,192 +166,5 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -/** - * Following is function declaration - * that computes her2 for transposed case. - * It handles triangular part of matrix and - * remaining computation in optimal way to - * gain performance improvement. - * a is triangular matrix, x and y are vectors - */ -void bli_dher2_zen_int_4 - ( - double *a, - double *x, - double *y, - double *alpha, - dim_t m, - dim_t lda - ); - -void bli_dher2_unf_var4 - ( - uplo_t uplo, - conj_t conjx, - conj_t conjy, - conj_t conjh, - dim_t m, - double* alpha, - double* x, inc_t incx, - double* y, inc_t incy, - double* c, inc_t rs_c, inc_t cs_c, - cntx_t* cntx - ) -{ - - double* chi1; - double* x2; - double* psi1; - double* y2; - double* gamma11; - double* c21; - double alpha0; - double alpha0_psi1; - double alpha1_chi1; - double alpha0_chi1_psi1; - dim_t i; - dim_t n_ahead; - inc_t rs_ct, cs_ct; - - const num_t dt = PASTEMAC(d,type); - - /* The algorithm will be expressed in terms of the lower triangular - * case; the upper triangular case is supported by swapping the row - * and column strides of A and toggling some conj parameters. - */ - if ( bli_is_lower( uplo ) ) - { - rs_ct = rs_c; - cs_ct = cs_c; - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - else /* if ( bli_is_upper( uplo ) ) */ - { - rs_ct = cs_c; - cs_ct = rs_c; - - /* Toggle conjugation of conjx/conjy, but only if we are being - * invoked as her2; for syr2, conjx/conjy are unchanged. - */ - - PASTEMAC(d,copys)( *alpha, alpha0 ); - } - /* Apply conjh (which carries the conjugation component of the - * Hermitian transpose, if applicable) to conjx and/or conjy as - * needed to arrive at the effective conjugation for the vector - * subproblems. - */ - - PASTECH(d,axpy2v_ker_ft) kfp_2v; - - /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - - if((incx == 1) && (incy == 1) && (rs_ct == 1)) - { - for ( i = 0; i < m; ) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - if((n_ahead >= 3)) - { - bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); - i+= 4; - } - else - { - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - i+=1; - } - } - } - else - { - for ( i = 0; i < m; ++i) - { - n_ahead = m - i - 1; - chi1 = x + (i ) * incx; - x2 = x + (i+1) * incx; - psi1 = y + (i ) * incy; - y2 = y + (i+1) * incy; - gamma11 = c + (i ) + (i )*cs_ct; - c21 = c + (i+1) + (i )*cs_ct; - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ - PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); - - /* c21 = c21 + alpha * x2 * conj(psi1); */ - /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ - - kfp_2v - ( - conjx, - conjy, - n_ahead, - &alpha0_psi1, - &alpha1_chi1, - x2, incx, - y2, incy, - c21, rs_ct, - cntx - ); - - - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - } - } -} - -GENTFUNC(float, s, her2_unf_var4) -GENTFUNC(scomplex, c, her2_unf_var4) -GENTFUNC(dcomplex, z,her2_unf_var4) -#else INSERT_GENTFUNC_BASIC0( her2_unf_var4 ) -#endif diff --git a/frame/2/her2/bli_her2_unf_var4_amd.c b/frame/2/her2/bli_her2_unf_var4_amd.c new file mode 100644 index 0000000000..4d77397cd2 --- /dev/null +++ b/frame/2/her2/bli_her2_unf_var4_amd.c @@ -0,0 +1,354 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uplo, \ + conj_t conjx, \ + conj_t conjy, \ + conj_t conjh, \ + dim_t m, \ + ctype* alpha, \ + ctype* x, inc_t incx, \ + ctype* y, inc_t incy, \ + ctype* c, inc_t rs_c, inc_t cs_c, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* chi1; \ + ctype* x2; \ + ctype* psi1; \ + ctype* y2; \ + ctype* gamma11; \ + ctype* c21; \ + ctype alpha0; \ + ctype alpha1; \ + ctype alpha0_psi1; \ + ctype alpha1_chi1; \ + ctype alpha0_chi1_psi1; \ + ctype conjy0_psi1; \ + ctype conjx1_chi1; \ + ctype conjx0_chi1; \ + dim_t i; \ + dim_t n_ahead; \ + inc_t rs_ct, cs_ct; \ + conj_t conj0, conj1; \ + conj_t conjh_conjx; \ + conj_t conjh_conjy; \ +\ + /* Eliminate unused variable warnings. */ \ + ( void )conjh_conjx; \ + ( void )conjh_conjy; \ +\ + /* The algorithm will be expressed in terms of the lower triangular case; + the upper triangular case is supported by swapping the row and column + strides of A and toggling some conj parameters. */ \ + if ( bli_is_lower( uplo ) ) \ + { \ + rs_ct = rs_c; \ + cs_ct = cs_c; \ +\ + PASTEMAC(ch,copys)( *alpha, alpha0 ); \ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha1 ); \ + } \ + else /* if ( bli_is_upper( uplo ) ) */ \ + { \ + rs_ct = cs_c; \ + cs_ct = rs_c; \ +\ + /* Toggle conjugation of conjx/conjy, but only if we are being invoked + as her2; for syr2, conjx/conjy are unchanged. */ \ + conjx = bli_apply_conj( conjh, conjx ); \ + conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTEMAC(ch,copycjs)( conjh, *alpha, alpha0 ); \ + PASTEMAC(ch,copys)( *alpha, alpha1 ); \ + } \ +\ + /* Apply conjh (which carries the conjugation component of the Hermitian + transpose, if applicable) to conjx and/or conjy as needed to arrive at + the effective conjugation for the vector subproblems. */ \ + conj0 = conjx; \ + conj1 = conjy; \ + conjh_conjx = bli_apply_conj( conjh, conjx ); \ + conjh_conjy = bli_apply_conj( conjh, conjy ); \ +\ + PASTECH(ch,axpy2v_ker_ft) kfp_2v; \ +\ + /* Query the context for the kernel function pointer. */ \ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); \ +\ + for ( i = 0; i < m; ++i ) \ + { \ + n_ahead = m - i - 1; \ + chi1 = x + (i )*incx; \ + x2 = x + (i+1)*incx; \ + psi1 = y + (i )*incy; \ + y2 = y + (i+1)*incy; \ + gamma11 = c + (i )*rs_ct + (i )*cs_ct; \ + c21 = c + (i+1)*rs_ct + (i )*cs_ct; \ +\ + /* Apply conjx and/or conjy to chi1 and/or psi1. */ \ + PASTEMAC(ch,copycjs)( conjh_conjy, *psi1, conjy0_psi1 ); \ + PASTEMAC(ch,copycjs)( conjh_conjx, *chi1, conjx1_chi1 ); \ + PASTEMAC(ch,copycjs)( conj0, *chi1, conjx0_chi1 ); \ +\ + /* Compute scalars for vector subproblems. */ \ + PASTEMAC(ch,scal2s)( alpha0, conjy0_psi1, alpha0_psi1 ); \ + PASTEMAC(ch,scal2s)( alpha1, conjx1_chi1, alpha1_chi1 ); \ +\ + /* Compute alpha * chi1 * conj(psi1) after both chi1 and psi1 have + already been conjugated, if needed, by conjx and conjy. */ \ + PASTEMAC(ch,scal2s)( alpha0_psi1, conjx0_chi1, alpha0_chi1_psi1 ); \ +\ + /* c21 = c21 + alpha * x2 * conj(psi1); */ \ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ \ + kfp_2v \ + ( \ + conj0, \ + conj1, \ + n_ahead, \ + &alpha0_psi1, \ + &alpha1_chi1, \ + x2, incx, \ + y2, incy, \ + c21, rs_ct, \ + cntx \ + ); \ +\ + /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) \ + + conj(alpha) * psi1 * conj(chi1); */ \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ + PASTEMAC(ch,adds)( alpha0_chi1_psi1, *gamma11 ); \ +\ + /* For her2, explicitly set the imaginary component of gamma11 to + zero. */ \ + if ( bli_is_conj( conjh ) ) \ + PASTEMAC(ch,seti0s)( *gamma11 ); \ + } \ +} + +/** + * Following is function declaration + * that computes her2 for transposed case. + * It handles triangular part of matrix and + * remaining computation in optimal way to + * gain performance improvement. + * a is triangular matrix, x and y are vectors + */ +void bli_dher2_zen_int_4 + ( + double *a, + double *x, + double *y, + double *alpha, + dim_t m, + dim_t lda + ); + +void bli_dher2_unf_var4 + ( + uplo_t uplo, + conj_t conjx, + conj_t conjy, + conj_t conjh, + dim_t m, + double* alpha, + double* x, inc_t incx, + double* y, inc_t incy, + double* c, inc_t rs_c, inc_t cs_c, + cntx_t* cntx + ) +{ + + double* chi1; + double* x2; + double* psi1; + double* y2; + double* gamma11; + double* c21; + double alpha0; + double alpha0_psi1; + double alpha1_chi1; + double alpha0_chi1_psi1; + dim_t i; + dim_t n_ahead; + inc_t rs_ct, cs_ct; + + const num_t dt = PASTEMAC(d,type); + + /* The algorithm will be expressed in terms of the lower triangular + * case; the upper triangular case is supported by swapping the row + * and column strides of A and toggling some conj parameters. + */ + if ( bli_is_lower( uplo ) ) + { + rs_ct = rs_c; + cs_ct = cs_c; + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + else /* if ( bli_is_upper( uplo ) ) */ + { + rs_ct = cs_c; + cs_ct = rs_c; + + /* Toggle conjugation of conjx/conjy, but only if we are being + * invoked as her2; for syr2, conjx/conjy are unchanged. + */ + + PASTEMAC(d,copys)( *alpha, alpha0 ); + } + /* Apply conjh (which carries the conjugation component of the + * Hermitian transpose, if applicable) to conjx and/or conjy as + * needed to arrive at the effective conjugation for the vector + * subproblems. + */ + + PASTECH(d,axpy2v_ker_ft) kfp_2v; + + /* Query the context for the kernel function pointer. */ + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + + if((incx == 1) && (incy == 1) && (rs_ct == 1)) + { + for ( i = 0; i < m; ) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + if((n_ahead >= 3)) + { + bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + i+= 4; + } + else + { + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + i+=1; + } + } + } + else + { + for ( i = 0; i < m; ++i) + { + n_ahead = m - i - 1; + chi1 = x + (i ) * incx; + x2 = x + (i+1) * incx; + psi1 = y + (i ) * incy; + y2 = y + (i+1) * incy; + gamma11 = c + (i ) + (i )*cs_ct; + c21 = c + (i+1) + (i )*cs_ct; + + /* Compute scalars for vector subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) after both chi1 + * and psi1 have + already been conjugated, if needed, by conjx and + conjy. */ + PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, + alpha0_chi1_psi1 ); + + /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ + + kfp_2v + ( + conjx, + conjy, + n_ahead, + &alpha0_psi1, + &alpha1_chi1, + x2, incx, + y2, incy, + c21, rs_ct, + cntx + ); + + + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + } + } +} + +GENTFUNC(float, s, her2_unf_var4) +GENTFUNC(scomplex, c, her2_unf_var4) +GENTFUNC(dcomplex, z,her2_unf_var4) + + diff --git a/frame/2/trsv/bli_trsv_unf_var1.c b/frame/2/trsv/bli_trsv_unf_var1.c index 4f19e1ac5e..55e28a4417 100644 --- a/frame/2/trsv/bli_trsv_unf_var1.c +++ b/frame/2/trsv/bli_trsv_unf_var1.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -231,413 +231,4 @@ void PASTEMAC(ch,varname) \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dtrsv_unf_var1 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - cntx_t* cntx - ) -{ - - double* one = PASTEMAC(d,1); - double* minus_one = PASTEMAC(d,m1); - double* A10; - double* A11; - double* A12; - double* a10t; - double* alpha11; - double* a12t; - double* x0; - double* x1; - double* x2; - double* x01; - double* chi11; - double* x21; - double alpha11_conj; - double rho1; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_behind, f_behind; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(d,dotxf_ker_ft) kfp_df; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_df = bli_ddotxf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(d,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_behind = iter; - A11 = a + (i )*rs_at + (i )*cs_at; - A12 = a + (i )*rs_at + (i+f)*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 - A12 * x2; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A12, cs_at, rs_at, - x2, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_behind = k; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a12t = A11 + (l )*rs_at + (l+1)*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 - a12t * x21; */ - PASTEMAC(d,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - PASTEMAC(d,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - } - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_behind = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A10 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 - A10 * x0; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A10, cs_at, rs_at, - x0, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_behind = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a10t = A11 + (l )*rs_at + (0 )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 - a10t * x01; */ - PASTEMAC(d,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(d,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - PASTEMAC(d,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - } - } - } -} - -void bli_strsv_unf_var1 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - cntx_t* cntx - ) -{ - - float* one = PASTEMAC(s,1); - float* minus_one = PASTEMAC(s,m1); - float* A10; - float* A11; - float* A12; - float* a10t; - float* alpha11; - float* a12t; - float* x0; - float* x1; - float* x2; - float* x01; - float* chi11; - float* x21; - float alpha11_conj; - float rho1; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_behind, f_behind; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(s,dotxf_ker_ft) kfp_df; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_df = bli_sdotxf_zen_int_8; - b_fuse = 8; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - num_t dt = PASTEMAC(s,type); - kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); - - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_behind = iter; - A11 = a + (i )*rs_at + (i )*cs_at; - A12 = a + (i )*rs_at + (i+f)*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 - A12 * x2; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A12, cs_at, rs_at, - x2, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_behind = k; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a12t = A11 + (l )*rs_at + (l+1)*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 - a12t * x21; */ - PASTEMAC(s,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); - } - PASTEMAC(s,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); - } - } - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_behind = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A10 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 - A10 * x0; */ - kfp_df - ( - conja, - BLIS_NO_CONJUGATE, - n_behind, - f, - minus_one, - A10, cs_at, rs_at, - x0, incx, - one, - x1, incx, - cntx - ); - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_behind = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a10t = A11 + (l )*rs_at + (0 )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 - a10t * x01; */ - PASTEMAC(s,set0s)( rho1 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - else - { - for ( j = 0; j < f_behind; ++j ) - PASTEMAC(s,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); - } - PASTEMAC(s,subs)( rho1, *chi11 ); - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); - } - } - } - } -} - -INSERT_GENTFUNC_BASIC0_CZ( trsv_unf_var1 ) -#else INSERT_GENTFUNC_BASIC0( trsv_unf_var1 ) -#endif diff --git a/frame/2/trsv/bli_trsv_unf_var1_amd.c b/frame/2/trsv/bli_trsv_unf_var1_amd.c new file mode 100644 index 0000000000..4f026f2c6a --- /dev/null +++ b/frame/2/trsv/bli_trsv_unf_var1_amd.c @@ -0,0 +1,638 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + cntx_t* cntx \ + ) \ +{ \ + if(cntx == NULL) cntx = bli_gks_query_cntx(); \ + const num_t dt = PASTEMAC(ch,type); \ +\ + ctype* one = PASTEMAC(ch,1); \ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A10; \ + ctype* A11; \ + ctype* A12; \ + ctype* a10t; \ + ctype* alpha11; \ + ctype* a12t; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype rho1; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_behind, f_behind; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,dotxf_ker_ft) kfp_df; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. */ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_behind = iter; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A12 = a + (i )*rs_at + (i+f)*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 - A12 * x2; */ \ + kfp_df \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_behind, \ + f, \ + minus_one, \ + A12, cs_at, rs_at, \ + x2, incx, \ + one, \ + x1, incx, \ + cntx \ + ); \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_behind = k; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a12t = A11 + (l )*rs_at + (l+1)*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + /* chi11 = chi11 - a12t * x21; */ \ + PASTEMAC(ch,set0s)( rho1 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); \ + } \ + PASTEMAC(ch,subs)( rho1, *chi11 ); \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ + } \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_behind = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A10 = a + (i )*rs_at + (0 )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 - A10 * x0; */ \ + kfp_df \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_behind, \ + f, \ + minus_one, \ + A10, cs_at, rs_at, \ + x0, incx, \ + one, \ + x1, incx, \ + cntx \ + ); \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_behind = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a10t = A11 + (l )*rs_at + (0 )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 - a10t * x01; */ \ + PASTEMAC(ch,set0s)( rho1 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \ + } \ + else \ + { \ + for ( j = 0; j < f_behind; ++j ) \ + PASTEMAC(ch,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); \ + } \ + PASTEMAC(ch,subs)( rho1, *chi11 ); \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ + } \ + } \ + } \ +} + +void bli_dtrsv_unf_var1 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + cntx_t* cntx + ) +{ + + double* one = PASTEMAC(d,1); + double* minus_one = PASTEMAC(d,m1); + double* A10; + double* A11; + double* A12; + double* a10t; + double* alpha11; + double* a12t; + double* x0; + double* x1; + double* x2; + double* x01; + double* chi11; + double* x21; + double alpha11_conj; + double rho1; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_behind, f_behind; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + /* x = alpha * x; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(d,dotxf_ker_ft) kfp_df; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_df = bli_ddotxf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(d,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_behind = iter; + A11 = a + (i )*rs_at + (i )*cs_at; + A12 = a + (i )*rs_at + (i+f)*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 - A12 * x2; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A12, cs_at, rs_at, + x2, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_behind = k; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a12t = A11 + (l )*rs_at + (l+1)*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 - a12t * x21; */ + PASTEMAC(d,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + PASTEMAC(d,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + } + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_behind = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A10 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 - A10 * x0; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A10, cs_at, rs_at, + x0, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_behind = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a10t = A11 + (l )*rs_at + (0 )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 - a10t * x01; */ + PASTEMAC(d,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(d,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + PASTEMAC(d,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + } + } + } +} + +void bli_strsv_unf_var1 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + cntx_t* cntx + ) +{ + + float* one = PASTEMAC(s,1); + float* minus_one = PASTEMAC(s,m1); + float* A10; + float* A11; + float* A12; + float* a10t; + float* alpha11; + float* a12t; + float* x0; + float* x1; + float* x2; + float* x01; + float* chi11; + float* x21; + float alpha11_conj; + float rho1; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_behind, f_behind; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + /* x = alpha * x; */ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(s,dotxf_ker_ft) kfp_df; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_df = bli_sdotxf_zen_int_8; + b_fuse = 8; + } + else + { + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + num_t dt = PASTEMAC(s,type); + kfp_df = bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_DF, cntx ); + + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_behind = iter; + A11 = a + (i )*rs_at + (i )*cs_at; + A12 = a + (i )*rs_at + (i+f)*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 - A12 * x2; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A12, cs_at, rs_at, + x2, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_behind = k; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a12t = A11 + (l )*rs_at + (l+1)*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 - a12t * x21; */ + PASTEMAC(s,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dotjs)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dots)( *(a12t + j*cs_at), *(x21 + j*incx), rho1 ); + } + PASTEMAC(s,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); + } + } + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_behind = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A10 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 - A10 * x0; */ + kfp_df + ( + conja, + BLIS_NO_CONJUGATE, + n_behind, + f, + minus_one, + A10, cs_at, rs_at, + x0, incx, + one, + x1, incx, + cntx + ); + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_behind = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a10t = A11 + (l )*rs_at + (0 )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 - a10t * x01; */ + PASTEMAC(s,set0s)( rho1 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dotjs)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + else + { + for ( j = 0; j < f_behind; ++j ) + PASTEMAC(s,dots)( *(a10t + j*cs_at), *(x01 + j*incx), rho1 ); + } + PASTEMAC(s,subs)( rho1, *chi11 ); + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s,invscals)( alpha11_conj, *chi11 ); + } + } + } + } +} + +INSERT_GENTFUNC_BASIC0_CZ( trsv_unf_var1 ) + diff --git a/frame/2/trsv/bli_trsv_unf_var2.c b/frame/2/trsv/bli_trsv_unf_var2.c index 7ece8f8470..9eb02781a4 100644 --- a/frame/2/trsv/bli_trsv_unf_var2.c +++ b/frame/2/trsv/bli_trsv_unf_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -228,789 +228,5 @@ void PASTEMAC(ch,varname) \ } \ } \ } -#ifdef BLIS_CONFIG_EPYC -void bli_dtrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - double* alpha, - double* a, inc_t rs_a, inc_t cs_a, - double* x, inc_t incx, - cntx_t* cntx - ) -{ - double* minus_one = PASTEMAC(d,m1); - double* A01; - double* A11; - double* A21; - double* a01; - double* alpha11; - double* a21; - double* x0; - double* x1; - double* x2; - double* x01; - double* chi11; - double* x21; - double alpha11_conj; - double minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if ( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(d,axpyf_ker_ft) kfp_af; - - /* Assign kernel function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_daxpyf_zen_int_16x4; - b_fuse = 4; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(d,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_strsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - float* alpha, - float* a, inc_t rs_a, inc_t cs_a, - float* x, inc_t incx, - cntx_t* cntx - ) -{ - - float* minus_one = PASTEMAC(s, m1); - float* A01; - float* A11; - float* A21; - float* a01; - float* alpha11; - float* a21; - float* x0; - float* x1; - float* x2; - float* x01; - float* chi11; - float* x21; - float alpha11_conj; - float minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(s, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(s, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_saxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); - } - - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(s, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_ztrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - dcomplex* alpha, - dcomplex* a, inc_t rs_a, inc_t cs_a, - dcomplex* x, inc_t incx, - cntx_t* cntx - ) -{ - - dcomplex* minus_one = PASTEMAC(z, m1); - dcomplex* A01; - dcomplex* A11; - dcomplex* A21; - dcomplex* a01; - dcomplex* alpha11; - dcomplex* a21; - dcomplex* x0; - dcomplex* x1; - dcomplex* x2; - dcomplex* x01; - dcomplex* chi11; - dcomplex* x21; - dcomplex alpha11_conj; - dcomplex minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(z, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_zaxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); - } - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(z, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -void bli_ctrsv_unf_var2 - ( - uplo_t uploa, - trans_t transa, - diag_t diaga, - dim_t m, - scomplex* alpha, - scomplex* a, inc_t rs_a, inc_t cs_a, - scomplex* x, inc_t incx, - cntx_t* cntx - ) -{ - - scomplex* minus_one = PASTEMAC(c, m1); - scomplex* A01; - scomplex* A11; - scomplex* A21; - scomplex* a01; - scomplex* alpha11; - scomplex* a21; - scomplex* x0; - scomplex* x1; - scomplex* x2; - scomplex* x01; - scomplex* chi11; - scomplex* x21; - scomplex alpha11_conj; - scomplex minus_chi11; - dim_t iter, i, k, j, l; - dim_t b_fuse, f; - dim_t n_ahead, f_ahead; - inc_t rs_at, cs_at; - uplo_t uploa_trans; - conj_t conja; - - /* x = alpha * x; */ - PASTEMAC2(c, scalv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - m, - alpha, - x, incx, - cntx, - NULL - ); - - if( bli_does_notrans( transa ) ) - { - rs_at = rs_a; - cs_at = cs_a; - uploa_trans = uploa; - } - else /* if ( bli_does_trans( transa ) ) */ - { - rs_at = cs_a; - cs_at = rs_a; - uploa_trans = bli_uplo_toggled( uploa ); - } - - conja = bli_extract_conj( transa ); - - PASTECH(c, axpyf_ker_ft) kfp_af; - - /* Assign function pointer and fusing factor. */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - kfp_af = bli_caxpyf_zen_int_5; - b_fuse = 5; - } - else - { - if ( cntx == NULL ) cntx = bli_gks_query_cntx(); - kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); - b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); - } - /* We reduce all of the possible cases down to just lower/upper. */ - if ( bli_is_upper( uploa_trans ) ) - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); - i = m - iter - f; - n_ahead = i; - A11 = a + (i )*rs_at + (i )*cs_at; - A01 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x0 = x + (0 )*incx; - - /* x1 = x1 / triu( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = f - k - 1; - f_ahead = l; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a01 = A11 + (0 )*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x01 = x1 + (0 )*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); - } - - /* x01 = x01 - chi11 * a01; */ - PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); - } - } - - /* x0 = x0 - A01 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A01, rs_at, cs_at, - x1, incx, - x0, incx, - cntx - ); - } - } - else /* if ( bli_is_lower( uploa_trans ) ) */ - { - for ( iter = 0; iter < m; iter += f ) - { - f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); - i = iter; - n_ahead = m - iter - f; - A11 = a + (i )*rs_at + (i )*cs_at; - A21 = a + (i+f)*rs_at + (i )*cs_at; - x1 = x + (i )*incx; - x2 = x + (i+f)*incx; - - /* x1 = x1 / tril( A11 ); */ - for ( k = 0; k < f; ++k ) - { - l = k; - f_ahead = f - k - 1; - alpha11 = A11 + (l )*rs_at + (l )*cs_at; - a21 = A11 + (l+1)*rs_at + (l )*cs_at; - chi11 = x1 + (l )*incx; - x21 = x1 + (l+1)*incx; - - /* chi11 = chi11 / alpha11; */ - if ( bli_is_nonunit_diag( diaga ) ) - { - PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); - PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); - } - - /* x21 = x21 - chi11 * a21; */ - PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); - if ( bli_is_conj( conja ) ) - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - else - { - for ( j = 0; j < f_ahead; ++j ) - PASTEMAC(c, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); - } - } - - /* x2 = x2 - A21 * x1; */ - kfp_af - ( - conja, - BLIS_NO_CONJUGATE, - n_ahead, - f, - minus_one, - A21, rs_at, cs_at, - x1, incx, - x2, incx, - cntx - ); - } - } -} - -#else INSERT_GENTFUNC_BASIC0( trsv_unf_var2 ) -#endif diff --git a/frame/2/trsv/bli_trsv_unf_var2_amd.c b/frame/2/trsv/bli_trsv_unf_var2_amd.c new file mode 100644 index 0000000000..51bbcabab7 --- /dev/null +++ b/frame/2/trsv/bli_trsv_unf_var2_amd.c @@ -0,0 +1,1024 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +#undef GENTFUNC +#define GENTFUNC( ctype, ch, varname ) \ +\ +void PASTEMAC(ch,varname) \ + ( \ + uplo_t uploa, \ + trans_t transa, \ + diag_t diaga, \ + dim_t m, \ + ctype* alpha, \ + ctype* a, inc_t rs_a, inc_t cs_a, \ + ctype* x, inc_t incx, \ + cntx_t* cntx \ + ) \ +{ \ + const num_t dt = PASTEMAC(ch,type); \ +\ + bli_init_once(); \ +\ + if( cntx == NULL ) cntx = bli_gks_query_cntx(); \ +\ + ctype* minus_one = PASTEMAC(ch,m1); \ + ctype* A01; \ + ctype* A11; \ + ctype* A21; \ + ctype* a01; \ + ctype* alpha11; \ + ctype* a21; \ + ctype* x0; \ + ctype* x1; \ + ctype* x2; \ + ctype* x01; \ + ctype* chi11; \ + ctype* x21; \ + ctype alpha11_conj; \ + ctype minus_chi11; \ + dim_t iter, i, k, j, l; \ + dim_t b_fuse, f; \ + dim_t n_ahead, f_ahead; \ + inc_t rs_at, cs_at; \ + uplo_t uploa_trans; \ + conj_t conja; \ +\ + /* x = alpha * x; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + m, \ + alpha, \ + x, incx, \ + cntx, \ + NULL \ + ); \ +\ + if ( bli_does_notrans( transa ) ) \ + { \ + rs_at = rs_a; \ + cs_at = cs_a; \ + uploa_trans = uploa; \ + } \ + else /* if ( bli_does_trans( transa ) ) */ \ + { \ + rs_at = cs_a; \ + cs_at = rs_a; \ + uploa_trans = bli_uplo_toggled( uploa ); \ + } \ +\ + conja = bli_extract_conj( transa ); \ +\ + PASTECH(ch,axpyf_ker_ft) kfp_af; \ +\ + /* Query the context for the kernel function pointer and fusing factor. */ \ + kfp_af = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); \ + b_fuse = bli_cntx_get_blksz_def_dt( dt, BLIS_AF, cntx ); \ +\ + /* We reduce all of the possible cases down to just lower/upper. */ \ + if ( bli_is_upper( uploa_trans ) ) \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); \ + i = m - iter - f; \ + n_ahead = i; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A01 = a + (0 )*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x0 = x + (0 )*incx; \ +\ + /* x1 = x1 / triu( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = f - k - 1; \ + f_ahead = l; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a01 = A11 + (0 )*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x01 = x1 + (0 )*incx; \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x01 = x01 - chi11 * a01; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); \ + } \ + } \ +\ + /* x0 = x0 - A01 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A01, rs_at, cs_at, \ + x1, incx, \ + x0, incx, \ + cntx \ + ); \ + } \ + } \ + else /* if ( bli_is_lower( uploa_trans ) ) */ \ + { \ + for ( iter = 0; iter < m; iter += f ) \ + { \ + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); \ + i = iter; \ + n_ahead = m - iter - f; \ + A11 = a + (i )*rs_at + (i )*cs_at; \ + A21 = a + (i+f)*rs_at + (i )*cs_at; \ + x1 = x + (i )*incx; \ + x2 = x + (i+f)*incx; \ +\ + /* x1 = x1 / tril( A11 ); */ \ + for ( k = 0; k < f; ++k ) \ + { \ + l = k; \ + f_ahead = f - k - 1; \ + alpha11 = A11 + (l )*rs_at + (l )*cs_at; \ + a21 = A11 + (l+1)*rs_at + (l )*cs_at; \ + chi11 = x1 + (l )*incx; \ + x21 = x1 + (l+1)*incx; \ +\ + /* chi11 = chi11 / alpha11; */ \ + if ( bli_is_nonunit_diag( diaga ) ) \ + { \ + PASTEMAC(ch,copycjs)( conja, *alpha11, alpha11_conj ); \ + PASTEMAC(ch,invscals)( alpha11_conj, *chi11 ); \ + } \ +\ + /* x21 = x21 - chi11 * a21; */ \ + PASTEMAC(ch,neg2s)( *chi11, minus_chi11 ); \ + if ( bli_is_conj( conja ) ) \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + else \ + { \ + for ( j = 0; j < f_ahead; ++j ) \ + PASTEMAC(ch,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); \ + } \ + } \ +\ + /* x2 = x2 - A21 * x1; */ \ + kfp_af \ + ( \ + conja, \ + BLIS_NO_CONJUGATE, \ + n_ahead, \ + f, \ + minus_one, \ + A21, rs_at, cs_at, \ + x1, incx, \ + x2, incx, \ + cntx \ + ); \ + } \ + } \ +} + +void bli_dtrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + double* alpha, + double* a, inc_t rs_a, inc_t cs_a, + double* x, inc_t incx, + cntx_t* cntx + ) +{ + + double* minus_one = PASTEMAC(d,m1); + double* A01; + double* A11; + double* A21; + double* a01; + double* alpha11; + double* a21; + double* x0; + double* x1; + double* x2; + double* x01; + double* chi11; + double* x21; + double alpha11_conj; + double minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if ( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(d,axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_daxpyf_zen_int_16x4; + b_fuse = 4; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DOUBLE, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DOUBLE, BLIS_AF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(d,copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(d,invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(d,neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(d,axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_strsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + float* alpha, + float* a, inc_t rs_a, inc_t cs_a, + float* x, inc_t incx, + cntx_t* cntx + ) +{ + + float* minus_one = PASTEMAC(s, m1); + float* A01; + float* A11; + float* A21; + float* a01; + float* alpha11; + float* a21; + float* x0; + float* x1; + float* x2; + float* x01; + float* chi11; + float* x21; + float alpha11_conj; + float minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(s, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(s, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_saxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_FLOAT, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_FLOAT, BLIS_AF, cntx ); + } + + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(s, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(s, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(s, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(s, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_ztrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + dcomplex* alpha, + dcomplex* a, inc_t rs_a, inc_t cs_a, + dcomplex* x, inc_t incx, + cntx_t* cntx + ) +{ + + dcomplex* minus_one = PASTEMAC(z, m1); + dcomplex* A01; + dcomplex* A11; + dcomplex* A21; + dcomplex* a01; + dcomplex* alpha11; + dcomplex* a21; + dcomplex* x0; + dcomplex* x1; + dcomplex* x2; + dcomplex* x01; + dcomplex* chi11; + dcomplex* x21; + dcomplex alpha11_conj; + dcomplex minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(z, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(z, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_zaxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_DCOMPLEX, BLIS_AF, cntx ); + } + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(z, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(z, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(z, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(z, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} + +void bli_ctrsv_unf_var2 + ( + uplo_t uploa, + trans_t transa, + diag_t diaga, + dim_t m, + scomplex* alpha, + scomplex* a, inc_t rs_a, inc_t cs_a, + scomplex* x, inc_t incx, + cntx_t* cntx + ) +{ + + scomplex* minus_one = PASTEMAC(c, m1); + scomplex* A01; + scomplex* A11; + scomplex* A21; + scomplex* a01; + scomplex* alpha11; + scomplex* a21; + scomplex* x0; + scomplex* x1; + scomplex* x2; + scomplex* x01; + scomplex* chi11; + scomplex* x21; + scomplex alpha11_conj; + scomplex minus_chi11; + dim_t iter, i, k, j, l; + dim_t b_fuse, f; + dim_t n_ahead, f_ahead; + inc_t rs_at, cs_at; + uplo_t uploa_trans; + conj_t conja; + + // For AMD these APIS are invoked skipping intermediate framework layers + // Hence we need to ensure that cntx is set here + bli_init_once(); + if( cntx == NULL ) cntx = bli_gks_query_cntx(); + + /* x = alpha * x; */ + PASTEMAC2(c, scalv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + m, + alpha, + x, incx, + cntx, + NULL + ); + + if( bli_does_notrans( transa ) ) + { + rs_at = rs_a; + cs_at = cs_a; + uploa_trans = uploa; + } + else /* if ( bli_does_trans( transa ) ) */ + { + rs_at = cs_a; + cs_at = rs_a; + uploa_trans = bli_uplo_toggled( uploa ); + } + + conja = bli_extract_conj( transa ); + + PASTECH(c, axpyf_ker_ft) kfp_af; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + kfp_af = bli_caxpyf_zen_int_5; + b_fuse = 5; + } + else + { + kfp_af = bli_cntx_get_l1f_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYF_KER, cntx ); + b_fuse = bli_cntx_get_blksz_def_dt( BLIS_SCOMPLEX, BLIS_AF, cntx ); + } + /* We reduce all of the possible cases down to just lower/upper. */ + if ( bli_is_upper( uploa_trans ) ) + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_b( iter, m, b_fuse ); + i = m - iter - f; + n_ahead = i; + A11 = a + (i )*rs_at + (i )*cs_at; + A01 = a + (0 )*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x0 = x + (0 )*incx; + + /* x1 = x1 / triu( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = f - k - 1; + f_ahead = l; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a01 = A11 + (0 )*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x01 = x1 + (0 )*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); + } + + /* x01 = x01 - chi11 * a01; */ + PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpyjs)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpys)( minus_chi11, *(a01 + j*rs_at), *(x01 + j*incx) ); + } + } + + /* x0 = x0 - A01 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A01, rs_at, cs_at, + x1, incx, + x0, incx, + cntx + ); + } + } + else /* if ( bli_is_lower( uploa_trans ) ) */ + { + for ( iter = 0; iter < m; iter += f ) + { + f = bli_determine_blocksize_dim_f( iter, m, b_fuse ); + i = iter; + n_ahead = m - iter - f; + A11 = a + (i )*rs_at + (i )*cs_at; + A21 = a + (i+f)*rs_at + (i )*cs_at; + x1 = x + (i )*incx; + x2 = x + (i+f)*incx; + + /* x1 = x1 / tril( A11 ); */ + for ( k = 0; k < f; ++k ) + { + l = k; + f_ahead = f - k - 1; + alpha11 = A11 + (l )*rs_at + (l )*cs_at; + a21 = A11 + (l+1)*rs_at + (l )*cs_at; + chi11 = x1 + (l )*incx; + x21 = x1 + (l+1)*incx; + + /* chi11 = chi11 / alpha11; */ + if ( bli_is_nonunit_diag( diaga ) ) + { + PASTEMAC(c, copycjs)( conja, *alpha11, alpha11_conj ); + PASTEMAC(c, invscals)( alpha11_conj, *chi11 ); + } + + /* x21 = x21 - chi11 * a21; */ + PASTEMAC(c, neg2s)( *chi11, minus_chi11 ); + if ( bli_is_conj( conja ) ) + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpyjs)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + else + { + for ( j = 0; j < f_ahead; ++j ) + PASTEMAC(c, axpys)( minus_chi11, *(a21 + j*rs_at), *(x21 + j*incx) ); + } + } + + /* x2 = x2 - A21 * x1; */ + kfp_af + ( + conja, + BLIS_NO_CONJUGATE, + n_ahead, + f, + minus_one, + A21, rs_at, cs_at, + x1, incx, + x2, incx, + cntx + ); + } + } +} diff --git a/frame/3/bli_l3_sup_int.c b/frame/3/bli_l3_sup_int.c index 7ef4bdd49f..909f480599 100644 --- a/frame/3/bli_l3_sup_int.c +++ b/frame/3/bli_l3_sup_int.c @@ -48,120 +48,6 @@ err_t bli_gemmsup_int { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); -#ifdef BLIS_CONFIG_EPYC - const num_t dt = bli_obj_dt( c ); - const dim_t m = bli_obj_length( c ); - const dim_t n = bli_obj_width( c ); - const dim_t k = bli_obj_width( a ); - const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); - const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); - const bool auto_factor = bli_rntm_auto_factor( rntm ); - const dim_t n_threads = bli_rntm_num_threads( rntm ); - - dim_t jc_new; - dim_t ic_new; - - - //bli_gemmsup_ref_var2 - //bli_gemmsup_ref_var1 - #if 0 - bli_gemmsup_ref_var1n - #else - #endif - const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); - const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || - stor_id == BLIS_RRC || - stor_id == BLIS_RCR || - stor_id == BLIS_CRR ); - #ifdef TRACEVAR - if ( bli_thread_am_ochief( thread ) ) - printf( "bli_l3_sup_int(): var2m primary\n" ); - #endif - - // Don't use the small/unpacked implementation if one of the matrices - // uses general stride. - if ( stor_id == BLIS_XXX ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); - return BLIS_FAILURE; - } - - if ( is_rrr_rrc_rcr_crr ) - { - // This branch handles: - // - rrr rrc rcr crr for row-preferential kernels - // - rcc crc ccr ccc for column-preferential kernels - // - Currently only row-preferential kernels are only supported. - - // calculate number of micropanels in m and n dimensions and - // recalculate the automatic thread factorization based on these number of micropanels - const dim_t mu = m / MR; - const dim_t nu = n / NR; - - // If the parallel thread factorization was automatic, we update it - // with a new factorization based on the matrix dimensions in units - // of micropanels. - if ( auto_factor ) - { - // In the block-panel algorithm, the m dimension is parallelized - // with ic_nt and the n dimension is parallelized with jc_nt. - bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); - - // Update the ways of parallelism for the jc and ic loops, and then - // update the current thread's root thrinfo_t node according to the - // new ways of parallelism value for the jc loop. - bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); - bli_l3_sup_thrinfo_update_root( rntm, thread ); - } - - /*Enable packing for B matrix for higher sizes*/ - if(bli_is_float(dt) && (n_threads==1)) { - if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_b( 1, rntm ); - } - - bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); - } - else - { - // This branch handles: - // - rrr rrc rcr crr for column-preferential kernels - // - rcc crc ccr ccc for row-preferential kernels - // - Currently only row-preferential kernels are only supported. - const dim_t mu = n / MR; // the n becomes m after a transposition - const dim_t nu = m / NR; // the m becomes n after a transposition - - if ( auto_factor ) - { - // In the block-panel algorithm, the m dimension is parallelized - // with ic_nt and the n dimension is parallelized with jc_nt. - bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); - - // Update the ways of parallelism for the jc and ic loops, and then - // update the current thread's root thrinfo_t node according to the - // new ways of parallelism value for the jc loop. - bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); - bli_l3_sup_thrinfo_update_root( rntm, thread ); - } - - /* Enable packing for B matrix for higher sizes. Note that pack A - * becomes pack B inside var2m because this is transpose case*/ - if(bli_is_float(dt) && (n_threads==1)) { - if((m > 240) && (k > 240) && (n > 240)) - bli_rntm_set_pack_a( 1, rntm ); - } - - bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, - alpha, a, b, beta, c, - stor_id, cntx, rntm, thread ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); - return BLIS_SUCCESS; - -#else // #ifdef BLIS_CONFIG_EPYC - const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); // Don't use the small/unpacked implementation if one of the matrices @@ -335,8 +221,6 @@ err_t bli_gemmsup_int // Return success so that the caller knows that we computed the solution. AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) return BLIS_SUCCESS; - -#endif } // ----------------------------------------------------------------------------- @@ -401,15 +285,9 @@ err_t bli_gemmtsup_int // Decide which algorithm to use (block-panel var2m or panel-block // var1n) based on the number of micropanels in the m and n dimensions. // Also, recalculate the automatic thread factorization. -#ifdef BLIS_CONFIG_EPYC - if ( mu >= nu ) use_bp = TRUE; - else /* if ( mu < nu ) */ use_bp = TRUE;// var1n is not implemented for GEMMT - -#else if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; -#endif // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. @@ -472,14 +350,10 @@ err_t bli_gemmtsup_int // Decide which algorithm to use (block-panel var2m or panel-block // var1n) based on the number of micropanels in the m and n dimensions. // Also, recalculate the automatic thread factorization. -#ifdef BLIS_CONFIG_EPYC - if ( mu >= nu ) use_bp = TRUE; - else /* if ( mu < nu ) */ use_bp = TRUE; //var1n is not implemented for gemmt -#else + if ( mu >= nu ) use_bp = TRUE; else /* if ( mu < nu ) */ use_bp = FALSE; -#endif // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units // of micropanels. diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c new file mode 100644 index 0000000000..7bd44266d2 --- /dev/null +++ b/frame/3/bli_l3_sup_int_amd.c @@ -0,0 +1,352 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +err_t bli_gemmsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); + + const num_t dt = bli_obj_dt( c ); + const dim_t m = bli_obj_length( c ); + const dim_t n = bli_obj_width( c ); + const dim_t k = bli_obj_width( a ); + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + + dim_t jc_new; + dim_t ic_new; + + + //bli_gemmsup_ref_var2 + //bli_gemmsup_ref_var1 + #if 0 + bli_gemmsup_ref_var1n + #else + #endif + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + if ( is_rrr_rrc_rcr_crr ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + // - Currently only row-preferential kernels are only supported. + + // calculate number of micropanels in m and n dimensions and + // recalculate the automatic thread factorization based on these number of micropanels + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + /*Enable packing for B matrix for higher sizes*/ + if(bli_is_float(dt) && (n_threads==1)) { + if((m > 240) && (k > 240) && (n > 240)) + bli_rntm_set_pack_b( 1, rntm ); + } + + bli_gemmsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + // - Currently only row-preferential kernels are only supported. + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + if ( auto_factor ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + /* Enable packing for B matrix for higher sizes. Note that pack A + * becomes pack B inside var2m because this is transpose case*/ + if(bli_is_float(dt) && (n_threads==1)) { + if((m > 240) && (k > 240) && (n > 240)) + bli_rntm_set_pack_a( 1, rntm ); + } + + bli_gemmsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4); + return BLIS_SUCCESS; + + +} + +// ----------------------------------------------------------------------------- + +err_t bli_gemmtsup_int + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + thrinfo_t* thread + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4); +// AOCL_DTL_LOG_GEMMT_INPUTS(AOCL_DTL_LEVEL_TRACE_4, alpha, a, b, beta, c); + + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + + // Don't use the small/unpacked implementation if one of the matrices + // uses general stride. + if ( stor_id == BLIS_XXX ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_4, "SUP doesn't support general stide."); + return BLIS_FAILURE; + } + + const bool is_rrr_rrc_rcr_crr = ( stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR ); + const bool is_rcc_crc_ccr_ccc = !is_rrr_rrc_rcr_crr; + + const num_t dt = bli_obj_dt( c ); + const bool row_pref = bli_cntx_l3_sup_ker_prefers_rows_dt( dt, stor_id, cntx ); + + const bool is_primary = ( row_pref ? is_rrr_rrc_rcr_crr + : is_rcc_crc_ccr_ccc ); + + const dim_t m = bli_obj_length( c ); + const dim_t n = m; + const dim_t MR = bli_cntx_get_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_blksz_def_dt( dt, BLIS_NR, cntx ); + const bool auto_factor = bli_rntm_auto_factor( rntm ); + const dim_t n_threads = bli_rntm_num_threads( rntm ); + bool use_bp = TRUE; + dim_t jc_new; + dim_t ic_new; + + + if ( is_primary ) + { + // This branch handles: + // - rrr rrc rcr crr for row-preferential kernels + // - rcc crc ccr ccc for column-preferential kernels + + const dim_t mu = m / MR; + const dim_t nu = n / NR; + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = TRUE;// var1n is not implemented for GEMMT + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m primary\n" ); + #endif + // block-panel macrokernel; m -> mc, mr; n -> nc, nr: var2() + bli_gemmtsup_ref_var2m( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n primary\n" ); + #endif + // panel-block macrokernel; m -> nc*,mr; n -> mc*,nr: var1() + bli_gemmtsup_ref_var1n( BLIS_NO_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of nc up to be a multiple of mr. + } + } + else + { + // This branch handles: + // - rrr rrc rcr crr for column-preferential kernels + // - rcc crc ccr ccc for row-preferential kernels + + const dim_t mu = n / MR; // the n becomes m after a transposition + const dim_t nu = m / NR; // the m becomes n after a transposition + + // Decide which algorithm to use (block-panel var2m or panel-block + // var1n) based on the number of micropanels in the m and n dimensions. + // Also, recalculate the automatic thread factorization. + + if ( mu >= nu ) use_bp = TRUE; + else /* if ( mu < nu ) */ use_bp = TRUE; //var1n is not implemented for gemmt + + // If the parallel thread factorization was automatic, we update it + // with a new factorization based on the matrix dimensions in units + // of micropanels. + if ( auto_factor ) + { + if ( use_bp ) + { + // In the block-panel algorithm, the m dimension is parallelized + // with ic_nt and the n dimension is parallelized with jc_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &ic_new, &jc_new ); + } + else // if ( !use_bp ) + { + // In the panel-block algorithm, the m dimension is parallelized + // with jc_nt and the n dimension is parallelized with ic_nt. + bli_thread_partition_2x2( n_threads, mu, nu, &jc_new, &ic_new ); + } + + // Update the ways of parallelism for the jc and ic loops, and then + // update the current thread's root thrinfo_t node according to the + // new ways of parallelism value for the jc loop. + bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); + bli_l3_sup_thrinfo_update_root( rntm, thread ); + } + + + if ( use_bp ) + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var2m non-primary\n" ); + #endif + // panel-block macrokernel; m -> nc, nr; n -> mc, mr: var2() + trans + bli_gemmtsup_ref_var2m( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + } + else // use_pb + { + #ifdef TRACEVAR + if ( bli_thread_am_ochief( thread ) ) + printf( "bli_l3_sup_int(): var1n non-primary\n" ); + #endif + // block-panel macrokernel; m -> mc*,nr; n -> nc*,mr: var1() + trans + bli_gemmtsup_ref_var1n( BLIS_TRANSPOSE, + alpha, a, b, beta, c, + stor_id, cntx, rntm, thread ); + // *requires nudging of mc up to be a multiple of nr. + } + } + + // Return success so that the caller knows that we computed the solution. + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return BLIS_SUCCESS; +} + diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index c782559167..d19d2eaea3 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -176,20 +176,7 @@ void bli_gemm_front dim_t m_dim_local = bli_obj_length( &c_local ); dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width_after_trans( &a_local ); -#ifdef BLIS_CONFIG_EPYC - // Regression observed in sgemm native path in cases where m >= 4 * n - // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit - // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for - // the issue. - if( bli_obj_is_float( &c_local ) && - ( n_dim_local >= 1024 ) && - ( k_dim_local >= 1024 ) && - ( m_dim_local >= ( 4 * n_dim_local ) ) ) - { - m_dim_local *= 2; - } -#endif + dim_t k_dim_local = bli_obj_width( &a_local ); // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c new file mode 100644 index 0000000000..41af62007c --- /dev/null +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -0,0 +1,413 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_gemm_front + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); + bli_init_once(); + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_gemm_check( alpha, a, b, beta, c, cntx ); + + // If C has a zero dimension, return early. + if ( bli_obj_has_zero_dim( c ) ) + { + return; + } + + // If alpha is zero, or if A or B has a zero dimension, scale C by beta + // and return early. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) || + bli_obj_has_zero_dim( a ) || + bli_obj_has_zero_dim( b ) ) + { + bli_scalm( beta, c ); + return; + } + +#ifdef BLIS_ENABLE_SMALL_MATRIX + // Only handle small problems separately for homogeneous datatypes. + if ( bli_obj_dt( a ) == bli_obj_dt( b ) && + bli_obj_dt( a ) == bli_obj_dt( c ) && + bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) + { + err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); + return; + } + } +#endif + + // Alias A, B, and C in case we need to apply transformations. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( c, &c_local ); + +#ifdef BLIS_ENABLE_GEMM_MD + cntx_t cntx_local; + + // If any of the storage datatypes differ, or if the computation precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: If we ever want to support the caller setting the computation + // domain explicitly, we will need to check the computation dt against the + // storage dt of C (instead of the computation precision against the + // storage precision of C). + if ( bli_obj_dt( &c_local ) != bli_obj_dt( &a_local ) || + bli_obj_dt( &c_local ) != bli_obj_dt( &b_local ) || + bli_obj_comp_prec( &c_local ) != bli_obj_prec( &c_local ) ) + { + // Handle mixed datatype cases in bli_gemm_md(), which may modify + // the objects or the context. (If the context is modified, cntx + // is adjusted to point to cntx_local.) + bli_gemm_md( &a_local, &b_local, beta, &c_local, &cntx_local, &cntx ); + } + //else // homogeneous datatypes +#endif + + // Load the pack schemas from the context and embed them into the objects + // for A and B. (Native contexts are initialized with the correct pack + // schemas, as are contexts for 1m, and if necessary bli_gemm_md() would + // have made a copy and modified the schemas, so reading them from the + // context should be a safe bet at this point.) This is a sort of hack for + // communicating the desired pack schemas to bli_gemm_cntl_create() (via + // bli_l3_thread_decorator() and bli_l3_cntl_create_if()). This allows us + // to subsequently access the schemas from the control tree, which + // hopefully reduces some confusion, particularly in bli_packm_init(). + const pack_t schema_a = bli_cntx_schema_a_block( cntx ); + const pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Next, we handle the possibility of needing to typecast alpha to the + // computation datatype and/or beta to the storage datatype of C. + + // Attach alpha to B, and in the process typecast alpha to the target + // datatype of the matrix (which in this case is equal to the computation + // datatype). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, alpha, &b_local ); + + // Attach beta to C, and in the process typecast beta to the target + // datatype of the matrix (which in this case is equal to the storage + // datatype of C). + bli_obj_scalar_attach( BLIS_NO_CONJUGATE, beta, &c_local ); + + // Change the alpha and beta pointers to BLIS_ONE since the values have + // now been typecast and attached to the matrices above. + alpha = &BLIS_ONE; + beta = &BLIS_ONE; + +#ifdef BLIS_ENABLE_GEMM_MD + // Don't perform the following optimization for ccr or crc cases, as + // those cases are sensitive to the ukernel storage preference (ie: + // transposing the operation would break them). + if ( !bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) ) +#endif + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_obj_swap( &a_local, &b_local ); + + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + + // We must also swap the pack schemas, which were set by bli_gemm_md() + // or the inlined code above. + bli_obj_swap_pack_schemas( &a_local, &b_local ); + } + + dim_t m_dim_local = bli_obj_length( &c_local ); + dim_t n_dim_local = bli_obj_width( &c_local ); + dim_t k_dim_local = bli_obj_width( &a_local ); + + // Regression observed in sgemm native path in cases where m >= 4 * n + // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit + // 11dfc176a3c422729f453f6c23204cf023e9954d. Temporary workaround for + // the issue. + if( bli_obj_is_float( &c_local ) && + ( n_dim_local >= 1024 ) && + ( k_dim_local >= 1024 ) && + ( m_dim_local >= ( 4 * n_dim_local ) ) ) + { + m_dim_local *= 2; + } + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_GEMM, + BLIS_LEFT, // ignored for gemm/hemm/symm + m_dim_local, + n_dim_local, + k_dim_local, + rntm + ); + + obj_t* cp = &c_local; + obj_t* betap = beta; + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If any of the following conditions are met, create a temporary matrix + // conformal to C into which we will accumulate the matrix product: + // - the storage precision of C differs from the computation precision; + // - the domains are mixed as crr; + // - the storage format of C does not match the preferred orientation + // of the ccr or crc cases. + // Then, after the computation is complete, this matrix will be copied + // or accumulated back to C. + const bool is_ccr_mismatch = + ( bli_gemm_md_is_ccr( &a_local, &b_local, &c_local ) && + !bli_obj_is_col_stored( &c_local ) ); + const bool is_crc_mismatch = + ( bli_gemm_md_is_crc( &a_local, &b_local, &c_local ) && + !bli_obj_is_row_stored( &c_local ) ); + + obj_t ct; + bool use_ct = FALSE; + + // FGVZ: Consider adding another guard here that only creates and uses a + // temporary matrix for accumulation if k < c * kc, where c is some small + // constant like 2. And don't forget to use the same conditional for the + // castm() and free() at the end. + if ( + bli_obj_prec( &c_local ) != bli_obj_comp_prec( &c_local ) || + bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) || + is_ccr_mismatch || + is_crc_mismatch + ) + { + use_ct = TRUE; + } + + // If we need a temporary matrix conformal to C for whatever reason, + // we create it and prepare to use it now. + if ( use_ct ) + { + const dim_t m = bli_obj_length( &c_local ); + const dim_t n = bli_obj_width( &c_local ); + inc_t rs = bli_obj_row_stride( &c_local ); + inc_t cs = bli_obj_col_stride( &c_local ); + + num_t dt_ct = bli_obj_domain( &c_local ) | + bli_obj_comp_prec( &c_local ); + + // When performing the crr case, accumulate to a contiguously-stored + // real matrix so we do not have to repeatedly update C with general + // stride. + if ( bli_gemm_md_is_crr( &a_local, &b_local, &c_local ) ) + dt_ct = BLIS_REAL | bli_obj_comp_prec( &c_local ); + + // When performing the mismatched ccr or crc cases, now is the time + // to specify the appropriate storage so the gemm_md_c2r_ref() virtual + // microkernel can output directly to C (instead of using a temporary + // microtile). + if ( is_ccr_mismatch ) { rs = 1; cs = m; } + else if ( is_crc_mismatch ) { rs = n; cs = 1; } + + bli_obj_create( dt_ct, m, n, rs, cs, &ct ); + + const num_t dt_exec = bli_obj_exec_dt( &c_local ); + const num_t dt_comp = bli_obj_comp_dt( &c_local ); + + bli_obj_set_target_dt( dt_ct, &ct ); + bli_obj_set_exec_dt( dt_exec, &ct ); + bli_obj_set_comp_dt( dt_comp, &ct ); + + // A naive approach would cast C to the comptuation datatype, + // compute with beta, and then cast the result back to the + // user-provided output matrix. However, we employ a different + // approach that halves the number of memops on C (or its + // typecast temporary) by writing the A*B product directly to + // temporary storage, and then using xpbym to scale the + // output matrix by beta and accumulate/cast the A*B product. + //bli_castm( &c_local, &ct ); + betap = &BLIS_ZERO; + + cp = &ct; + } +#endif +#endif + + // Invoke the internal back-end via the thread handler. + bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_GEMM, // operation family id + alpha, + &a_local, + &b_local, + betap, + cp, + cntx, + rntm, + cntl + ); + +#ifdef BLIS_ENABLE_GEMM_MD +#ifdef BLIS_ENABLE_GEMM_MD_EXTRA_MEM + // If we created a temporary matrix conformal to C for whatever reason, + // we copy/accumulate the result back to C and then release the object. + if ( use_ct ) + { + obj_t beta_local; + + bli_obj_scalar_detach( &c_local, &beta_local ); + + //bli_castnzm( &ct, &c_local ); + bli_xpbym( &ct, &beta_local, &c_local ); + + bli_obj_free( &ct ); + } +#endif +#endif + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); +} + +// ----------------------------------------------------------------------------- + +#if 0 + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + const bool a_is_real = bli_obj_is_real( a ); + const bool a_is_comp = bli_obj_is_complex( a ); + const bool b_is_real = bli_obj_is_real( b ); + const bool b_is_comp = bli_obj_is_complex( b ); + const bool c_is_real = bli_obj_is_real( c ); + const bool c_is_comp = bli_obj_is_complex( c ); + + const bool a_is_single = bli_obj_is_single_prec( a ); + const bool a_is_double = bli_obj_is_double_prec( a ); + const bool b_is_single = bli_obj_is_single_prec( b ); + const bool b_is_double = bli_obj_is_double_prec( b ); + const bool c_is_single = bli_obj_is_single_prec( c ); + const bool c_is_double = bli_obj_is_double_prec( c ); + + const bool comp_single = bli_obj_comp_prec( c ) == BLIS_SINGLE_PREC; + const bool comp_double = bli_obj_comp_prec( c ) == BLIS_DOUBLE_PREC; + + const bool mixeddomain = bli_obj_domain( c ) != bli_obj_domain( a ) || + bli_obj_domain( c ) != bli_obj_domain( b ); + + ( void )a_is_real; ( void )a_is_comp; + ( void )b_is_real; ( void )b_is_comp; + ( void )c_is_real; ( void )c_is_comp; + ( void )a_is_single; ( void )a_is_double; + ( void )b_is_single; ( void )b_is_double; + ( void )c_is_single; ( void )c_is_double; + ( void )comp_single; ( void )comp_double; + + if ( + //( c_is_comp && a_is_comp && b_is_real ) || + //( c_is_comp && a_is_real && b_is_comp ) || + //( c_is_real && a_is_comp && b_is_comp ) || + //( c_is_comp && a_is_real && b_is_real ) || + //( c_is_real && a_is_comp && b_is_real ) || + //( c_is_real && a_is_real && b_is_comp ) || + //FALSE + TRUE + ) + { + if ( + ( c_is_single && a_is_single && b_is_single && mixeddomain ) || + ( c_is_single && a_is_single && b_is_single && comp_single ) || + ( c_is_single && a_is_single && b_is_single && comp_double ) || + ( c_is_single && a_is_single && b_is_double ) || + ( c_is_single && a_is_double && b_is_single ) || + ( c_is_double && a_is_single && b_is_single ) || + ( c_is_single && a_is_double && b_is_double ) || + ( c_is_double && a_is_single && b_is_double ) || + ( c_is_double && a_is_double && b_is_single ) || + ( c_is_double && a_is_double && b_is_double && comp_single ) || + ( c_is_double && a_is_double && b_is_double && comp_double ) || + ( c_is_double && a_is_double && b_is_double && mixeddomain ) || + FALSE + ) + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + } + else + bli_gemm_md_zgemm( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#else +#if 0 + // If any of the storage datatypes differ, or if the execution precision + // differs from the storage precision of C, utilize the mixed datatype + // code path. + // NOTE: We could check the exec dt against the storage dt of C, but for + // now we don't support the caller setting the execution domain + // explicitly. + if ( bli_obj_dt( a ) != bli_obj_dt( b ) || + bli_obj_dt( a ) != bli_obj_dt( c ) || + bli_obj_comp_prec( c ) != bli_obj_prec( c ) ) + { + bli_gemm_md_front( alpha, a, b, beta, c, cntx, cntl ); + return; + } +#endif +#endif + diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index 4b3837544f..db698e9d0f 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -459,6 +459,25 @@ bool bli_cpuid_is_bulldozer return TRUE; } +bool bli_cpuid_is_avx_supported( void ) +{ + uint32_t family, model, features; + + // Call the CPUID instruction and parse its results into a family id, + // model id, and a feature bit field. The return value encodes the + // vendor. + bli_cpuid_query( &family, &model, &features ); + + // Check for expected CPU features. + const uint32_t expected = FEATURE_AVX | + FEATURE_FMA3 | + FEATURE_AVX2; + + if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + + return TRUE; +} + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) arch_t bli_cpuid_query_id( void ) diff --git a/frame/base/bli_cpuid.h b/frame/base/bli_cpuid.h index 62c05ad5ca..47b584c883 100644 --- a/frame/base/bli_cpuid.h +++ b/frame/base/bli_cpuid.h @@ -132,7 +132,7 @@ BLIS_INLINE bool bli_cpuid_has_features( uint32_t have, uint32_t want ) void get_cpu_name( char *cpu_name ); int vpu_count( void ); - +bool bli_cpuid_is_avx_supported(void); enum { @@ -159,6 +159,8 @@ enum FEATURE_AVX512VL = 0x4000 }; + + #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) char* find_string_in( char* target, char* buffer, size_t buf_len, char* filepath ); diff --git a/frame/compat/bla_amax.c b/frame/compat/bla_amax.c index fabed6e72d..b1cf77e7b8 100644 --- a/frame/compat/bla_amax.c +++ b/frame/compat/bla_amax.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,211 +98,5 @@ f77_int PASTEF772(i,chx,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -f77_int isamax_ - ( - const f77_int* n, - const float* x, const f77_int* incx - ) -{ - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); - - dim_t n0; - float* x0; - inc_t incx0; - gint_t bli_index; - f77_int f77_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( *n < 1 || *incx <= 0 ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "isamax_: vector empty"); - return 0; - } - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_samaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - } - else - { - PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) - ( - n0, - x0, incx0, - &bli_index, - NULL, - NULL - ); - } - - /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) - index. Also, if the BLAS integer size differs from the BLIS - integer size, that typecast occurs here. */ - f77_index = bli_index + 1; - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return f77_index; -} - -f77_int idamax_ - ( - const f77_int* n, - const double* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx); - - dim_t n0; - double* x0; - inc_t incx0; - gint_t bli_index; - f77_int f77_index; - - /* If the vector is empty, return an index of zero. This early check - is needed to emulate netlib BLAS. Without it, bli_?amaxv() will - return 0, which ends up getting incremented to 1 (below) before - being returned, which is not what we want. */ - if ( *n < 1 || *incx <= 0 ) { - AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "idamax_: vector empty"); - return 0; - } - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_damaxv_zen_int - ( - n0, - x0, incx0, - &bli_index, - NULL - ); - } - else - { - PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) - ( - n0, - x0, incx0, - &bli_index, - NULL, - NULL - ); - } - - /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) - index. Also, if the BLAS integer size differs from the BLIS - integer size, that typecast occurs here. */ - f77_index = bli_index + 1; - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return f77_index; -} - -INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) -#else INSERT_GENTFUNC_BLAS( amax, amaxv ) #endif -#endif diff --git a/frame/compat/bla_amax_amd.c b/frame/compat/bla_amax_amd.c new file mode 100644 index 0000000000..7f1a771f7c --- /dev/null +++ b/frame/compat/bla_amax_amd.c @@ -0,0 +1,295 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype_x, chx, blasname, blisname ) \ +\ +f77_int PASTEF772(i,chx,blasname) \ + ( \ + const f77_int* n, \ + const ftype_x* x, const f77_int* incx \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(chx), *n, *incx) \ +\ + dim_t n0; \ + ftype_x* x0; \ + inc_t incx0; \ + gint_t bli_index; \ + f77_int f77_index; \ +\ + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ \ + if ( *n < 1 || *incx <= 0 ) { \ + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "iamax_: vector empty") \ + return 0; \ + }\ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype_x*)x, *incx, x0, incx0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(chx,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + &bli_index, \ + NULL, \ + NULL \ + ); \ +\ + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ \ + f77_index = bli_index + 1; \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + return f77_index; \ +} + +#ifdef BLIS_ENABLE_BLAS + +f77_int isamax_ + ( + const f77_int* n, + const float* x, const f77_int* incx + ) +{ + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx); + + dim_t n0; + float* x0; + inc_t incx0; + gint_t bli_index; + f77_int f77_index; + + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ + if ( *n < 1 || *incx <= 0 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "isamax_: vector empty"); + return 0; + } + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_samaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(s,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ + f77_index = bli_index + 1; + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return f77_index; +} + +f77_int idamax_ + ( + const f77_int* n, + const double* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_AMAX_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx); + + dim_t n0; + double* x0; + inc_t incx0; + gint_t bli_index; + f77_int f77_index; + + /* If the vector is empty, return an index of zero. This early check + is needed to emulate netlib BLAS. Without it, bli_?amaxv() will + return 0, which ends up getting incremented to 1 (below) before + being returned, which is not what we want. */ + if ( *n < 1 || *incx <= 0 ) { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_1, "idamax_: vector empty"); + return 0; + } + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_damaxv_zen_int + ( + n0, + x0, incx0, + &bli_index, + NULL + ); + } + else + { + PASTEMAC2(d,amaxv,BLIS_TAPI_EX_SUF) + ( + n0, + x0, incx0, + &bli_index, + NULL, + NULL + ); + } + + /* Convert zero-based BLIS (C) index to one-based BLAS (Fortran) + index. Also, if the BLAS integer size differs from the BLIS + integer size, that typecast occurs here. */ + f77_index = bli_index + 1; + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return f77_index; +} + +INSERT_GENTFUNC_BLAS_CZ( amax, amaxv ) + +#endif diff --git a/frame/compat/bla_axpy.c b/frame/compat/bla_axpy.c index 41885e95d6..1a30f417b3 100644 --- a/frame/compat/bla_axpy.c +++ b/frame/compat/bla_axpy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,399 +87,6 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void saxpy_ -( - const f77_int* n, - const float* alpha, - const float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float*)alpha, *incx, *incy) - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((float*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_saxpyv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (float*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - - } - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - -void daxpy_ -( - const f77_int* n, - const double* alpha, - const double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((double*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_daxpyv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (double*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -void caxpy_ -( - const f77_int* n, - const scomplex* alpha, - const scomplex* x, const f77_int* incx, - scomplex* y, const f77_int* incy - ) -{ - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_caxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (scomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -void zaxpy_ -( - const f77_int* n, - const dcomplex* alpha, - const dcomplex* x, const f77_int* incx, - dcomplex* y, const f77_int* incy - ) -{ - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) - - /* Initialize BLIS. */ - // bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - bli_zaxpyv_zen_int5 - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL - ); - - } - else - { - PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - (dcomplex*)alpha, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - // bli_finalize_auto(); -} - -#else INSERT_GENTFUNC_BLAS( axpy, axpyv ) -#endif #endif diff --git a/frame/compat/bla_axpy_amd.c b/frame/compat/bla_axpy_amd.c new file mode 100644 index 0000000000..8a9f0280c6 --- /dev/null +++ b/frame/compat/bla_axpy_amd.c @@ -0,0 +1,462 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, (void*)alpha, *incx, *incy) \ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + (ftype*)alpha, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void saxpy_ +( + const f77_int* n, + const float* alpha, + const float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, (float*)alpha, *incx, *incy) + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_saxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(s,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (float*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void daxpy_ +( + const f77_int* n, + const double* alpha, + const double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, (double*)alpha, *incx, *incy) + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_daxpyv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(d,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (double*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void caxpy_ +( + const f77_int* n, + const scomplex* alpha, + const scomplex* x, const f77_int* incx, + scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, (scomplex*)alpha, *incx, *incy) + + /* Initialize BLIS. */ + // bli_init_auto(); + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_caxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(c,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (scomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + +void zaxpy_ +( + const f77_int* n, + const dcomplex* alpha, + const dcomplex* x, const f77_int* incx, + dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_AXPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, (dcomplex*)alpha, *incx, *incy) + + /* Initialize BLIS. */ + // bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_zaxpyv_zen_int5 + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL + ); + + } + else + { + PASTEMAC2(z,axpyv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + (dcomplex*)alpha, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + // bli_finalize_auto(); +} + + + +#endif diff --git a/frame/compat/bla_copy.c b/frame/compat/bla_copy.c index 61df88cf1e..74baba689c 100644 --- a/frame/compat/bla_copy.c +++ b/frame/compat/bla_copy.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -88,211 +88,5 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void scopy_ -( - const f77_int* n, - const float* x, const f77_int* incx, - float* y, const f77_int* incy -) -{ - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if (*n < 0) - n0 = (dim_t)0; - else - n0 = (dim_t)(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if (*incx < 0) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (float*)((x)+(n0 - 1)*(-*incx)); - incx0 = (inc_t)(*incx); - - } - else - { - x0 = (float*)(x); - incx0 = (inc_t)(*incx); - } - - if (*incy < 0) - { - y0 = (y)+(n0 - 1)*(-*incy); - incy0 = (inc_t)(*incy); - - } - else - { - y0 = (y); - incy0 = (inc_t)(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_scopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else - { - PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ -// bli_finalize_auto(); -} - -void dcopy_ -( - const f77_int* n, - const double* x, const f77_int* incx, - double* y, const f77_int* incy -) -{ - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if (*n < 0) - n0 = (dim_t)0; - else - n0 = (dim_t)(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if (*incx < 0) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (double*)((x)+(n0 - 1)*(-*incx)); - incx0 = (inc_t)(*incx); - - } - else - { - x0 = (double*)(x); - incx0 = (inc_t)(*incx); - } - - if (*incy < 0) - { - y0 = (y)+(n0 - 1)*(-*incy); - incy0 = (inc_t)(*incy); - - } - else - { - y0 = (y); - incy0 = (inc_t)(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel */ - bli_dcopyv_zen_int - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else - { - PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - NULL, - NULL - ); - } - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ -// bli_finalize_auto(); -} - -INSERT_GENTFUNC_BLAS_CZ(copy, copyv) -#else INSERT_GENTFUNC_BLAS(copy, copyv) #endif -#endif diff --git a/frame/compat/bla_copy_amd.c b/frame/compat/bla_copy_amd.c new file mode 100644 index 0000000000..8dc4d5287c --- /dev/null +++ b/frame/compat/bla_copy_amd.c @@ -0,0 +1,285 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy) \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv(n0, (ftype*)x, *incx, x0, incx0); \ + bli_convert_blas_incv(n0, (ftype*)y, *incy, y0, incy0); \ + \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch, blisname, BLIS_TAPI_EX_SUF) \ + (\ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ +\ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void scopy_ +( + const f77_int* n, + const float* x, const f77_int* incx, + float* y, const f77_int* incy +) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy) + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (float*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (float*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_scopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(s, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +void dcopy_ +( + const f77_int* n, + const double* x, const f77_int* incx, + double* y, const f77_int* incy +) +{ + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_COPY_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy) + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if (*n < 0) + n0 = (dim_t)0; + else + n0 = (dim_t)(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if (*incx < 0) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (double*)((x)+(n0 - 1)*(-*incx)); + incx0 = (inc_t)(*incx); + + } + else + { + x0 = (double*)(x); + incx0 = (inc_t)(*incx); + } + + if (*incy < 0) + { + y0 = (y)+(n0 - 1)*(-*incy); + incy0 = (inc_t)(*incy); + + } + else + { + y0 = (y); + incy0 = (inc_t)(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel */ + bli_dcopyv_zen_int + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else + { + PASTEMAC2(d, copyv, BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + NULL, + NULL + ); + } + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ +// bli_finalize_auto(); +} + +INSERT_GENTFUNC_BLAS_CZ(copy, copyv) + +#endif diff --git a/frame/compat/bla_dot.c b/frame/compat/bla_dot.c index 2a0f815217..3c4d8c538f 100644 --- a/frame/compat/bla_dot.c +++ b/frame/compat/bla_dot.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -90,663 +90,11 @@ ftype PASTEF772(ch,blasname,chc) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -float sdot_ - ( - const f77_int* n, - const float* x, const f77_int* incx, - const float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - float rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((float*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((float*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_sdotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} - -double ddot_ - ( - const f77_int* n, - const double* x, const f77_int* incx, - const double* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - double rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((double*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((double*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_ddotv_zen_int10 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} -#else INSERT_GENTFUNCDOTR_BLAS( dot, dotv ) -#endif #ifdef BLIS_ENABLE_BLAS #ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL -#ifdef BLIS_CONFIG_EPYC -scomplex cdotu_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - scomplex rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return rho; -} - -dcomplex zdotu_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - dcomplex rho; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_NO_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} - - -scomplex cdotc_ - ( - const f77_int* n, - const scomplex* x, const f77_int* incx, - const scomplex* y, const f77_int* incy - ) -{ - dim_t n0; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - scomplex rho; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((scomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((scomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_cdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} - -dcomplex zdotc_ - ( - const f77_int* n, - const dcomplex* x, const f77_int* incx, - const dcomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); - dim_t n0; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - dcomplex rho; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = ((dcomplex*)x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = ((dcomplex*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen) - { - /* Call BLIS kernel. */ - bli_zdotv_zen_int5 - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - BLIS_CONJUGATE, - BLIS_NO_CONJUGATE, - n0, - x0, incx0, - y0, incy0, - &rho, - NULL, - NULL - ); - } - - - - - - /* Finalize BLIS. */ -// bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - - return rho; -} -#else INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) -#endif #else // For the "intel" complex return type, use a hidden parameter to return the result #undef GENTFUNCDOT @@ -801,8 +149,8 @@ void PASTEF772(ch,blasname,chc) \ } INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) -#endif -#endif +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL +#endif // BLIS_ENABLE_BLAS // -- "Black sheep" dot product function definitions -- @@ -876,4 +224,4 @@ double PASTEF77(d,sdot) return rho; } -#endif +#endif // BLIS_ENABLE_BLAS diff --git a/frame/compat/bla_dot_amd.c b/frame/compat/bla_dot_amd.c new file mode 100644 index 0000000000..0cdaa6535b --- /dev/null +++ b/frame/compat/bla_dot_amd.c @@ -0,0 +1,841 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +ftype PASTEF772(ch,blasname,chc) \ + ( \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return rho; \ +} + +#ifdef BLIS_ENABLE_BLAS +float sdot_ + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + float rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((float*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_sdotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(s,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +double ddot_ + ( + const f77_int* n, + const double* x, const f77_int* incx, + const double* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + double rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((double*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_ddotv_zen_int10 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(d,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +#ifdef BLIS_DISABLE_COMPLEX_RETURN_INTEL +scomplex cdotu_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return rho; +} + +dcomplex zdotu_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_NO_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + + +scomplex cdotc_ + ( + const f77_int* n, + const scomplex* x, const f77_int* incx, + const scomplex* y, const f77_int* incy + ) +{ + dim_t n0; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + scomplex rho; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *n, *incx, *incy); + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((scomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((scomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_cdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +dcomplex zdotc_ + ( + const f77_int* n, + const dcomplex* x, const f77_int* incx, + const dcomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *n, *incx, *incy); + dim_t n0; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + dcomplex rho; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = ((dcomplex*)x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = ((dcomplex*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) + { + /* Call BLIS kernel. */ + bli_zdotv_zen_int5 + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + BLIS_CONJUGATE, + BLIS_NO_CONJUGATE, + n0, + x0, incx0, + y0, incy0, + &rho, + NULL, + NULL + ); + } + + + + + + /* Finalize BLIS. */ +// bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +#else // BLIS_DISABLE_COMPLEX_RETURN_INTEL +// For the "intel" complex return type, use a hidden parameter to return the result +#undef GENTFUNCDOT +#define GENTFUNCDOT( ftype, ch, chc, blis_conjx, blasname, blisname ) \ +\ +void PASTEF772(ch,blasname,chc) \ + ( \ + ftype* rhop, \ + const f77_int* n, \ + const ftype* x, const f77_int* incx, \ + const ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *n, *incx, *incy); \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + ftype rho; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_conjx, \ + BLIS_NO_CONJUGATE, \ + n0, \ + x0, incx0, \ + y0, incy0, \ + &rho, \ + NULL, \ + NULL \ + ); \ +\ + /* Finalize BLIS. */ \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + bli_finalize_auto(); \ +\ + *rhop = rho; \ +} + +INSERT_GENTFUNCDOTC_BLAS( dot, dotv ) +#endif // BLIS_DISABLE_COMPLEX_RETURN_INTEL + + + +// -- "Black sheep" dot product function definitions -- + +// Input vectors stored in single precision, computed in double precision, +// with result returned in single precision. +float PASTEF77(sd,sdot) + ( + const f77_int* n, + const float* sb, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + return ( float ) + ( + ( double )(*sb) + + PASTEF77(d,sdot) + ( + n, + x, incx, + y, incy + ) + ); +} + +// Input vectors stored in single precision, computed in double precision, +// with result returned in double precision. +double PASTEF77(d,sdot) + ( + const f77_int* n, + const float* x, const f77_int* incx, + const float* y, const f77_int* incy + ) +{ + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + double rho; + dim_t i; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_DOTV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + /* Initialization of BLIS is not required. */ + + /* Convert/typecast negative values of n to zero. */ + bli_convert_blas_dim1( *n, n0 ); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + bli_convert_blas_incv( n0, (float*)x, *incx, x0, incx0 ); + bli_convert_blas_incv( n0, (float*)y, *incy, y0, incy0 ); + + rho = 0.0; + + for ( i = 0; i < n0; i++ ) + { + float* chi1 = x0 + (i )*incx0; + float* psi1 = y0 + (i )*incy0; + + bli_ddots( (( double )(*chi1)), + (( double )(*psi1)), rho ); + } + + /* Finalization of BLIS is not required, because initialization was + not required. */ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + + return rho; +} + +#endif diff --git a/frame/compat/bla_gemm.c b/frame/compat/bla_gemm.c index 3cc7845739..406ff69d53 100644 --- a/frame/compat/bla_gemm.c +++ b/frame/compat/bla_gemm.c @@ -300,509 +300,7 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void dgemm_ -( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const double* alpha, - const double* a, const f77_int* lda, - const double* b, const f77_int* ldb, - const double* beta, - double* c, const f77_int* ldc -) -{ - - - - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - - /* Initialize BLIS. */ - bli_init_auto(); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(d), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); - bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1(*m, m0); - bli_convert_blas_dim1(*n, n0); - bli_convert_blas_dim1(*k, k0); - - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (!bamdzen) - { - // This code is duplicated below, however we don't want to move it out of - // this IF block as it will affect the performance on Zen architetures - // Also this is temporary fix which will be replaced later. - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double *)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) - { - bli_dgemm_ref_k1_nn( m0, n0, k0, - (double*)alpha, - (double*)a, *lda, - (double*)b, *ldb, - (double*)beta, - c, *ldc - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - - return; - } - - if (n0 == 1) - { - if (bli_is_notrans(blis_transa)) - { - bli_dgemv_unf_var2( - BLIS_NO_TRANSPOSE, - bli_extract_conj(blis_transb), - m0, k0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - return; - } - else if (m0 == 1) - { - if (bli_is_notrans(blis_transb)) - { - bli_dgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var2( - blis_transb, - bli_extract_conj(blis_transa), - k0, n0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double*)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - //cntx_t* cntx = bli_gks_query_cntx(); - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. - - // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better - // - -#ifdef AOCL_DYNAMIC - if (nt && ((n0 > 10 ) || (k0 > 10)) ) -#else - if (nt) -#endif - { - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - // The code below will be called when number of threads = 1. - -#ifdef BLIS_ENABLE_SMALL_MATRIX - - //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) - if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) - { - err_t status; - if (bli_is_notrans(blis_transa)) - { - status = bli_dgemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - else - { - status = bli_dgemm_small_At ( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } - } - -#endif //#ifdef BLIS_ENABLE_SMALL_MATRIX - - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - // fall back on native path when dgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ - /* ( */ - /* &alphao, */ - /* &ao, */ - /* &bo, */ - /* &betao, */ - /* &co, */ - /* NULL, */ - /* NULL */ - /* ); */ - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); -} // end of dgemm_ - -void zgemm_ - ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* b, const f77_int* ldb, - const dcomplex* beta, - dcomplex* c, const f77_int* ldc - ) -{ - trans_t blis_transa; - trans_t blis_transb; - dim_t m0, n0, k0; - - /* Initialize BLIS. */ - bli_init_auto(); - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // default instance peformance tuning is done in zgemm. - // Single instance tuning is done based on env set. - dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); - - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. - if ( nt ) - { - // Will call parallelized zgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - // The code below will be called when number of threads = 1. -#if ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) - && ( k0 == 1120 ) ) //to be tuned further. - { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) - { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } -#endif//ENABLE_INDUCED_METHOD - -// native tuning resulted in better numbers compared to sup in constrained multi-instance -// sup has been enabled for single instance cases. - if(single_instance==1) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - - } - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ - bli_finalize_auto(); -}// end of zgemm_ - - -INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) -#else INSERT_GENTFUNC_BLAS( gemm,gemm ) -#endif #if 1 void dzgemm_ diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c new file mode 100644 index 0000000000..7ef58bfb35 --- /dev/null +++ b/frame/compat/bla_gemm_amd.c @@ -0,0 +1,894 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// +// Define BLAS-to-BLIS interfaces. +// +#define ENABLE_INDUCED_METHOD 0 +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_char* transb, \ + const f77_int* m, \ + const f77_int* n, \ + const f77_int* k, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* b, const f77_int* ldb, \ + const ftype* beta, \ + ftype* c, const f77_int* ldc \ + ) \ +{ \ +\ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ +\ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ +\ + if( n0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + BLIS_NO_TRANSPOSE, \ + bli_extract_conj(blis_transb), \ + m0, k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a,\ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*) beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transa, \ + bli_extract_conj(blis_transb), \ + k0, m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*)beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ + else if( m0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transb)) \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + n0, k0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + k0, n0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ +\ + const num_t dt = PASTEMAC(ch,type); \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ +\ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ +\ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ +\ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} +#endif + +#ifdef BLIS_ENABLE_BLAS +void dgemm_ +( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc +) +{ + + + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(d), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(d), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); + bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1(*m, m0); + bli_convert_blas_dim1(*n, n0); + bli_convert_blas_dim1(*k, k0); + + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + // This code is duplicated below, however we don't want to move it out of + // this IF block as it will affect the performance on Zen architetures + // Also this is temporary fix which will be replaced later. + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double *)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + { + bli_dgemm_ref_k1_nn( m0, n0, k0, + (double*)alpha, + (double*)a, *lda, + (double*)b, *ldb, + (double*)beta, + c, *ldc + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; + } + + if (n0 == 1) + { + if (bli_is_notrans(blis_transa)) + { + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; + } + else if (m0 == 1) + { + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double*)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double*)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double*)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double*)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + //cntx_t* cntx = bli_gks_query_cntx(); + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. + + // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better + // + +#ifdef AOCL_DYNAMIC + if (nt && ((n0 > 10 ) || (k0 > 10)) ) +#else + if (nt) +#endif + { + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + // The code below will be called when number of threads = 1. + +#ifdef BLIS_ENABLE_SMALL_MATRIX + + //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) + if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || + ((n0 <= 10) && (k0 <=10)) ) + { + err_t status; + if (bli_is_notrans(blis_transa)) + { + status = bli_dgemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else + { + status = bli_dgemm_small_At ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } + } + +#endif //#ifdef BLIS_ENABLE_SMALL_MATRIX + + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + // fall back on native path when dgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ + /* ( */ + /* &alphao, */ + /* &ao, */ + /* &bo, */ + /* &betao, */ + /* &co, */ + /* NULL, */ + /* NULL */ + /* ); */ + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); +} // end of dgemm_ + +void zgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // default instance peformance tuning is done in zgemm. + // Single instance tuning is done based on env set. + dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); + + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. + if ( nt ) + { + // Will call parallelized zgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + // The code below will be called when number of threads = 1. +#if ENABLE_INDUCED_METHOD + /* 3m_sqp is optimal for certain matrix shapes. + Initial study that it works well for square sizes and sizes closer to square shape. + + * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. + * Further investigation is necessary to make the usage choices more generic. */ + bool sqp_on = false; + if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) + { + sqp_on = true; + } + + // current range of sizes used for 3m_sqp to be expaned after evaluation. + if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) + && ( k0 == 1120 ) ) //to be tuned further. + { + sqp_on = true; + } + + if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) + { + //sqp algo is found better for n > 40 + if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + } +#endif//ENABLE_INDUCED_METHOD + +// native tuning resulted in better numbers compared to sup in constrained multi-instance +// sup has been enabled for single instance cases. + if(single_instance==1) + { + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + } + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of zgemm_ + + +INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) + + +// Observed a regression in dgemm with this function addition. +// Disabling temporarily. +#if 0 +void dzgemm_ + ( + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const dcomplex* alpha, + const double* a, const f77_int* lda, + const dcomplex* b, const f77_int* ldb, + const dcomplex* beta, + dcomplex* c, const f77_int* ldc + ) +{ + + trans_t blis_transa; + trans_t blis_transb; + dim_t m0, n0, k0; + + /* Initialize BLIS. */ + bli_init_auto(); + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + const num_t dt_a = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); +}// end of dzgemm_ +#endif +#endif diff --git a/frame/compat/bla_gemv.c b/frame/compat/bla_gemv.c index e9b210bbc1..9dba1b43c4 100644 --- a/frame/compat/bla_gemv.c +++ b/frame/compat/bla_gemv.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -147,844 +147,5 @@ void PASTEF77(ch,blasname) \ #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC -void dgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - const double* x, const f77_int* incx, - const double* beta, - double* y, const f77_int* incy - ) -{ - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(d), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) - { - m_y = m0; - n_x = n0; - } - else - { - m_y = n0; - n_x = m0; - } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - x0 = ((double*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((double*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((double*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((double*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) - { - //variant_2 is chosen for column-storage - // and uses axpyf-based implementation - bli_dgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - else - { - //var_1 is chosen for row-storage - //and uses dotxf-based implementation - bli_dgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (double*)alpha, - (double*)a, rs_a, cs_a, - x0, incx0, - (double*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - -void sgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - const float* x, const f77_int* incx, - const float* beta, - float* y, const f77_int* incy - ) -{ - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(s), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - //bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if ( *m < 0 ) m0 = ( dim_t )0; - else m0 = ( dim_t )(*m); - - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if ( bli_does_notrans( blis_transa ) ) - { - m_y = m0; - n_x = n0; - } - else - { - m_y = n0; - n_x = m0; - } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - if ( m_y > 0 && n_x == 0 ) - { - /* Finalize BLIS. */ - // bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - x0 = ((float*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((float*)x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((float*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((float*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Call variants based on transpose value. */ - if(bli_does_notrans(blis_transa)) - { - bli_sgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_sgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (float*)alpha, - (float*)a, rs_a, cs_a, - x0, incx0, - (float*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -void cgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - const scomplex* x, const f77_int* incx, - const scomplex* beta, - scomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - scomplex* x0; - scomplex* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(c), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - // bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if( *m < 0 ) m0 = (dim_t)0; - else m0 = (dim_t)(*m); - - if( *n < 0 ) n0 = (dim_t)0; - else n0 = (dim_t)(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - - if ( m_y > 0 && n_x == 0 ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if( *incx < 0 ) - { - x0 = ((scomplex*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((scomplex*)x); - incx0 = (inc_t)(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((scomplex*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((scomplex*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if( m_y == 1 ) - { - conj_t conja = bli_extract_conj(blis_transa); - scomplex rho; - if (bamdzen) - { - bli_cdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL, - NULL - ); - } - - scomplex yval = *y0; - if(!bli_ceq0(*beta)) - { - bli_cscals( *beta, yval ); - } - else - { - bli_csetsc( 0.0, 0.0, &yval); - } - if(!bli_ceq0(*alpha)) - { - bli_caxpys( *alpha, rho, yval); - } - y0->real = yval.real; - y0->imag = yval.imag; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* call variants based on transpose value */ - if( bli_does_notrans( blis_transa ) ) - { - bli_cgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_cgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - x0, incx0, - (scomplex*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -void zgemv_ - ( - const f77_char* transa, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - const dcomplex* x, const f77_int* incx, - const dcomplex* beta, - dcomplex* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); - AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); - - trans_t blis_transa; - dim_t m0, n0; - dim_t m_y, n_x; - dcomplex* x0; - dcomplex* y0; - inc_t incx0; - inc_t incy0; - inc_t rs_a, cs_a; - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemv) - ( - MKSTR(z), - MKSTR(gemv), - transa, - m, - n, - lda, - incx, - incy - ); - - if (*m == 0 || *n == 0) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; - else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; - else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; - else - { - // See comment for bli_param_map_netlib_to_blis_side() above. - // bli_check_error_code( BLIS_INVALID_TRANS ); - blis_transa = BLIS_NO_TRANSPOSE; - } - - /* Convert/typecast negative values of m and n to zero. */ - if( *m < 0 ) m0 = (dim_t)0; - else m0 = (dim_t)(*m); - - if( *n < 0 ) n0 = (dim_t)0; - else n0 = (dim_t)(*n); - - /* Determine the dimensions of x and y so we can adjust the increments, - if necessary.*/ - if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } - else { m_y = n0; n_x = m0; } - - /* BLAS handles cases where trans(A) has no columns, and x has no elements, - in a peculiar way. In these situations, BLAS returns without performing - any action, even though most sane interpretations of gemv would have the - the operation reduce to y := beta * y. Here, we catch those cases that - BLAS would normally mishandle and emulate the BLAS exactly so as to - provide "bug-for-bug" compatibility. Note that this extreme level of - compatibility would not be as much of an issue if it weren't for the - fact that some BLAS test suites actually test for these cases. Also, it - should be emphasized that BLIS, if called natively, does NOT exhibit - this quirky behavior; it will scale y by beta, as one would expect. */ - - if ( m_y > 0 && n_x == 0 ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if( *incx < 0 ) - { - x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); - incx0 = ( inc_t )(*incx); - } - else - { - x0 = ((dcomplex*)x); - incx0 = (inc_t)(*incx); - } - - if ( *incy < 0 ) - { - y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); - incy0 = ( inc_t )(*incy); - } - else - { - y0 = ((dcomplex*)y); - incy0 = ( inc_t )(*incy); - } - - /* Set the row and column strides of A. */ - rs_a = 1; - cs_a = *lda; - - // When dynamic dispatch is enabled i.e. library is built for ‘amdzen’ configuration. - // This function is invoked on all architectures including ‘generic’. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if( m_y == 1 ) - { - conj_t conja = bli_extract_conj(blis_transa); - dcomplex rho; - - if (bamdzen) - { - bli_zdotv_zen_int5 - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL - ); - } - else - { - /* Call BLIS interface. */ - PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) - ( - conja, - BLIS_NO_CONJUGATE, - n_x, - (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, - x0, incx0, - &rho, - NULL, - NULL - ); - } - - dcomplex yval = *y0; - if(!bli_zeq0(*beta)) - { - bli_zscals( *beta, yval ); - } - else - { - bli_zsetsc( 0.0, 0.0, &yval); - } - if(!bli_zeq0(*alpha)) - { - bli_zaxpys( *alpha, rho, yval); - } - y0->real = yval.real; - y0->imag = yval.imag; - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - if (bamdzen == 0) - { - /* Call BLIS interface. */ - PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* call variants based on transpose value */ - if( bli_does_notrans( blis_transa ) ) - { - bli_zgemv_unf_var2 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL - ); - } - else - { - bli_zgemv_unf_var1 - ( - blis_transa, - BLIS_NO_CONJUGATE, - m0, - n0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - x0, incx0, - (dcomplex*)beta, - y0, incy0, - NULL - ); - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); -} - - -#else INSERT_GENTFUNC_BLAS( gemv, gemv ) #endif -#endif diff --git a/frame/compat/bla_gemv_amd.c b/frame/compat/bla_gemv_amd.c new file mode 100644 index 0000000000..354f45fe1b --- /dev/null +++ b/frame/compat/bla_gemv_amd.c @@ -0,0 +1,963 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* transa, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + const ftype* x, const f77_int* incx, \ + const ftype* beta, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); \ + trans_t blis_transa; \ + dim_t m0, n0; \ + dim_t m_y, n_x; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ + inc_t rs_a, cs_a; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + m, \ + n, \ + lda, \ + incx, \ + incy \ + ); \ +\ + if (*m == 0 || *n == 0) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return; \ + } \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ +\ + /* Convert/typecast negative values of m and n to zero. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ \ + bli_set_dims_with_trans( blis_transa, m0, n0, &m_y, &n_x ); \ +\ + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ \ + if ( m_y > 0 && n_x == 0 ) \ + { \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +\ + return; \ + } \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n_x, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( m_y, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Set the row and column strides of A. */ \ + rs_a = 1; \ + cs_a = *lda; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + BLIS_NO_CONJUGATE, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + x0, incx0, \ + (ftype*)beta, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + + +#ifdef BLIS_ENABLE_BLAS +void dgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + const double* x, const f77_int* incx, + const double* beta, + double* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(d), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + x0 = ((double*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((double*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((double*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((double*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(d,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + //variant_2 is chosen for column-storage + // and uses axpyf-based implementation + bli_dgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + else + { + //var_1 is chosen for row-storage + //and uses dotxf-based implementation + bli_dgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (double*)alpha, + (double*)a, rs_a, cs_a, + x0, incx0, + (double*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + +void sgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + const float* x, const f77_int* incx, + const float* beta, + float* y, const f77_int* incy + ) +{ + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(s), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if ( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if ( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if ( *transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + //bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if ( *m < 0 ) m0 = ( dim_t )0; + else m0 = ( dim_t )(*m); + + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if ( bli_does_notrans( blis_transa ) ) + { + m_y = m0; + n_x = n0; + } + else + { + m_y = n0; + n_x = m0; + } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + if ( m_y > 0 && n_x == 0 ) + { + /* Finalize BLIS. */ + // bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + x0 = ((float*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((float*)x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((float*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((float*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(s,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Call variants based on transpose value. */ + if(bli_does_notrans(blis_transa)) + { + bli_sgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_sgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (float*)alpha, + (float*)a, rs_a, cs_a, + x0, incx0, + (float*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + +void cgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + const scomplex* x, const f77_int* incx, + const scomplex* beta, + scomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'C', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + scomplex* x0; + scomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(c), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if( *incx < 0 ) + { + x0 = ((scomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((scomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((scomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((scomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + scomplex rho; + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_cdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(c,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (scomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + scomplex yval = *y0; + if(!bli_ceq0(*beta)) + { + bli_cscals( *beta, yval ); + } + else + { + bli_csetsc( 0.0, 0.0, &yval); + } + if(!bli_ceq0(*alpha)) + { + bli_caxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(c,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_cgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_cgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + x0, incx0, + (scomplex*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + +void zgemv_ + ( + const f77_char* transa, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + const dcomplex* x, const f77_int* incx, + const dcomplex* beta, + dcomplex* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); + AOCL_DTL_LOG_GEMV_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'Z', *transa, *m, *n, (void*)alpha, *lda, *incx, (void*)beta, *incy); + + trans_t blis_transa; + dim_t m0, n0; + dim_t m_y, n_x; + dcomplex* x0; + dcomplex* y0; + inc_t incx0; + inc_t incy0; + inc_t rs_a, cs_a; + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(gemv) + ( + MKSTR(z), + MKSTR(gemv), + transa, + m, + n, + lda, + incx, + incy + ); + + if (*m == 0 || *n == 0) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + if( *transa == 'n' || *transa == 'N' ) blis_transa = BLIS_NO_TRANSPOSE; + else if( *transa == 't' || *transa == 'T' ) blis_transa = BLIS_TRANSPOSE; + else if( * transa == 'c' || *transa == 'C' ) blis_transa = BLIS_CONJ_TRANSPOSE; + else + { + // See comment for bli_param_map_netlib_to_blis_side() above. + // bli_check_error_code( BLIS_INVALID_TRANS ); + blis_transa = BLIS_NO_TRANSPOSE; + } + + /* Convert/typecast negative values of m and n to zero. */ + if( *m < 0 ) m0 = (dim_t)0; + else m0 = (dim_t)(*m); + + if( *n < 0 ) n0 = (dim_t)0; + else n0 = (dim_t)(*n); + + /* Determine the dimensions of x and y so we can adjust the increments, + if necessary.*/ + if( bli_does_notrans( blis_transa ) ) { m_y = m0, n_x = n0; } + else { m_y = n0; n_x = m0; } + + /* BLAS handles cases where trans(A) has no columns, and x has no elements, + in a peculiar way. In these situations, BLAS returns without performing + any action, even though most sane interpretations of gemv would have the + the operation reduce to y := beta * y. Here, we catch those cases that + BLAS would normally mishandle and emulate the BLAS exactly so as to + provide "bug-for-bug" compatibility. Note that this extreme level of + compatibility would not be as much of an issue if it weren't for the + fact that some BLAS test suites actually test for these cases. Also, it + should be emphasized that BLIS, if called natively, does NOT exhibit + this quirky behavior; it will scale y by beta, as one would expect. */ + + if ( m_y > 0 && n_x == 0 ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if( *incx < 0 ) + { + x0 = ((dcomplex*)x) + (n_x-1)*(-*incx); + incx0 = ( inc_t )(*incx); + } + else + { + x0 = ((dcomplex*)x); + incx0 = (inc_t)(*incx); + } + + if ( *incy < 0 ) + { + y0 = ((dcomplex*)y) + (m_y-1)*(-*incy); + incy0 = ( inc_t )(*incy); + } + else + { + y0 = ((dcomplex*)y); + incy0 = ( inc_t )(*incy); + } + + /* Set the row and column strides of A. */ + rs_a = 1; + cs_a = *lda; + + if( m_y == 1 ) + { + conj_t conja = bli_extract_conj(blis_transa); + dcomplex rho; + + if (bli_cpuid_is_avx_supported() == TRUE) + { + bli_zdotv_zen_int5 + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL + ); + } + else + { + /* Call BLIS interface. */ + PASTEMAC2(z,dotv,BLIS_TAPI_EX_SUF) + ( + conja, + BLIS_NO_CONJUGATE, + n_x, + (dcomplex*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, + x0, incx0, + &rho, + NULL, + NULL + ); + } + + dcomplex yval = *y0; + if(!bli_zeq0(*beta)) + { + bli_zscals( *beta, yval ); + } + else + { + bli_zsetsc( 0.0, 0.0, &yval); + } + if(!bli_zeq0(*alpha)) + { + bli_zaxpys( *alpha, rho, yval); + } + y0->real = yval.real; + y0->imag = yval.imag; + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + if (bli_cpuid_is_avx_supported() == FALSE) + { + /* Call BLIS interface. */ + PASTEMAC2(z,gemv,BLIS_TAPI_EX_SUF) + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* call variants based on transpose value */ + if( bli_does_notrans( blis_transa ) ) + { + bli_zgemv_unf_var2 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + else + { + bli_zgemv_unf_var1 + ( + blis_transa, + BLIS_NO_CONJUGATE, + m0, + n0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + x0, incx0, + (dcomplex*)beta, + y0, incy0, + NULL + ); + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); +} + + + +#endif diff --git a/frame/compat/bla_scal.c b/frame/compat/bla_scal.c index 30fd857bc7..b9651577eb 100644 --- a/frame/compat/bla_scal.c +++ b/frame/compat/bla_scal.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -93,171 +93,5 @@ void PASTEF772(chx,cha,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void sscal_ - ( - const f77_int* n, - const float* alpha, - float* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', (void *) alpha, *n, *incx ); - dim_t n0; - float* x0; - inc_t incx0; - /* Initialize BLIS. */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - bli_sscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (float *)alpha, - x0, incx0, - NULL - ); - } - else{ - PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE,\ - n0, \ - (float *)alpha,\ - x0, incx0,\ - NULL, \ - NULL \ - );\ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void dscal_ - ( - const f77_int* n, - const double* alpha, - double* x, const f77_int* incx - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', (void *)alpha, *n, *incx ); - dim_t n0; - double* x0; - inc_t incx0; - - /* Initialize BLIS */ - //bli_init_auto(); - - if (*n == 0 || alpha == NULL) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - return; - } - - /* Convert typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen){ - bli_dscalv_zen_int10 - ( - BLIS_NO_CONJUGATE, - n0, - (double*) alpha, - x0, incx0, - NULL - ); - } - else{ - PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ - ( \ - BLIS_NO_CONJUGATE,\ - n0, \ - (double *)alpha,\ - x0, incx0,\ - NULL, \ - NULL \ - );\ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) -#else INSERT_GENTFUNCSCAL_BLAS( scal, scalv ) #endif -#endif diff --git a/frame/compat/bla_scal_amd.c b/frame/compat/bla_scal_amd.c new file mode 100644 index 0000000000..178776a149 --- /dev/null +++ b/frame/compat/bla_scal_amd.c @@ -0,0 +1,260 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNCSCAL +#define GENTFUNCSCAL( ftype_x, ftype_a, chx, cha, blasname, blisname ) \ +\ +void PASTEF772(chx,cha,blasname) \ + ( \ + const f77_int* n, \ + const ftype_a* alpha, \ + ftype_x* x, const f77_int* incx \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + dim_t n0; \ + ftype_x* x0; \ + inc_t incx0; \ + ftype_x alpha_cast; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + if (*n == 0 || alpha == NULL) { \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); \ + return ; \ + } \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype_x*)x, *incx, x0, incx0 ); \ +\ + /* NOTE: We do not natively implement BLAS's csscal/zdscal in BLIS. + that is, we just always sub-optimally implement those cases + by casting alpha to ctype_x (potentially the complex domain) and + using the homogeneous datatype instance according to that type. */ \ + PASTEMAC2(cha,chx,copys)( *alpha, alpha_cast ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(chx,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE, \ + n0, \ + &alpha_cast, \ + x0, incx0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void sscal_ + ( + const f77_int* n, + const float* alpha, + float* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', (void *) alpha, *n, *incx ); + dim_t n0; + float* x0; + inc_t incx0; + /* Initialize BLIS. */ + //bli_init_auto(); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + bli_sscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (float *)alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(s,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (float *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void dscal_ + ( + const f77_int* n, + const double* alpha, + double* x, const f77_int* incx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SCAL_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', (void *)alpha, *n, *incx ); + dim_t n0; + double* x0; + inc_t incx0; + + /* Initialize BLIS */ + //bli_init_auto(); + + if (*n == 0 || alpha == NULL) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + return; + } + + /* Convert typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE){ + bli_dscalv_zen_int10 + ( + BLIS_NO_CONJUGATE, + n0, + (double*) alpha, + x0, incx0, + NULL + ); + } + else{ + PASTEMAC2(d,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + BLIS_NO_CONJUGATE,\ + n0, \ + (double *)alpha,\ + x0, incx0,\ + NULL, \ + NULL \ + );\ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNCSCAL_BLAS_CZ( scal, scalv ) + +#endif diff --git a/frame/compat/bla_swap.c b/frame/compat/bla_swap.c index 6ecb360f95..d653426478 100644 --- a/frame/compat/bla_swap.c +++ b/frame/compat/bla_swap.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -83,190 +83,5 @@ void PASTEF77(ch,blasname) \ } #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void sswap_ - ( - const f77_int* n, - float* x, const f77_int* incx, - float* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); - dim_t n0; - float* x0; - float* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = (y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = (y); - incy0 = ( inc_t )(*incy); - } - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { -/* Call BLIS kernel */ - bli_sswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else{ - PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ - ( \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -void dswap_ - ( - const f77_int* n, - double* x, const f77_int* incx, - double* y, const f77_int* incy - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) - AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); - dim_t n0; - double* x0; - double* y0; - inc_t incx0; - inc_t incy0; - - /* Initialize BLIS. */ -// bli_init_auto(); - - /* Convert/typecast negative values of n to zero. */ - if ( *n < 0 ) n0 = ( dim_t )0; - else n0 = ( dim_t )(*n); - - /* If the input increments are negative, adjust the pointers so we can - use positive increments instead. */ - if ( *incx < 0 ) - { - /* The semantics of negative stride in BLAS are that the vector - operand be traversed in reverse order. (Another way to think - of this is that negative strides effectively reverse the order - of the vector, but without any explicit data movements.) This - is also how BLIS interprets negative strides. The differences - is that with BLAS, the caller *always* passes in the 0th (i.e., - top-most or left-most) element of the vector, even when the - stride is negative. By contrast, in BLIS, negative strides are - used *relative* to the vector address as it is given. Thus, in - BLIS, if this backwards traversal is desired, the caller *must* - pass in the address to the (n-1)th (i.e., the bottom-most or - right-most) element along with a negative stride. */ - - x0 = (x) + (n0-1)*(-*incx); - incx0 = ( inc_t )(*incx); - - } - else - { - x0 = (x); - incx0 = ( inc_t )(*incx); - } - - if ( *incy < 0 ) - { - y0 = (y) + (n0-1)*(-*incy); - incy0 = ( inc_t )(*incy); - - } - else - { - y0 = (y); - incy0 = ( inc_t )(*incy); - } - - - /* Call BLIS kernel */ - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { - bli_dswapv_zen_int8 - ( - n0, - x0, incx0, - y0, incy0, - NULL - ); - } - else{ - PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ - ( \ - n0, \ - x0, incx0, \ - y0, incy0, \ - NULL, \ - NULL \ - ); \ - } - - /* Finalize BLIS. */ -// bli_finalize_auto(); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) -} - -INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) - -#else INSERT_GENTFUNC_BLAS( swap, swapv ) #endif -#endif diff --git a/frame/compat/bla_swap_amd.c b/frame/compat/bla_swap_amd.c new file mode 100644 index 0000000000..617c78a4aa --- /dev/null +++ b/frame/compat/bla_swap_amd.c @@ -0,0 +1,268 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_int* n, \ + ftype* x, const f77_int* incx, \ + ftype* y, const f77_int* incy \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) \ + dim_t n0; \ + ftype* x0; \ + ftype* y0; \ + inc_t incx0; \ + inc_t incy0; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Convert/typecast negative values of n to zero. */ \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ \ + bli_convert_blas_incv( n0, (ftype*)x, *incx, x0, incx0 ); \ + bli_convert_blas_incv( n0, (ftype*)y, *incy, y0, incy0 ); \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#ifdef BLIS_ENABLE_BLAS + +void sswap_ + ( + const f77_int* n, + float* x, const f77_int* incx, + float* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'S', *n, *incx, *incy); + dim_t n0; + float* x0; + float* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = (y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = (y); + incy0 = ( inc_t )(*incy); + } + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + /* Call BLIS kernel */ + bli_sswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(s,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +void dswap_ + ( + const f77_int* n, + double* x, const f77_int* incx, + double* y, const f77_int* incy + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) + AOCL_DTL_LOG_SWAP_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'D', *n, *incx, *incy); + dim_t n0; + double* x0; + double* y0; + inc_t incx0; + inc_t incy0; + + /* Initialize BLIS. */ +// bli_init_auto(); + + /* Convert/typecast negative values of n to zero. */ + if ( *n < 0 ) n0 = ( dim_t )0; + else n0 = ( dim_t )(*n); + + /* If the input increments are negative, adjust the pointers so we can + use positive increments instead. */ + if ( *incx < 0 ) + { + /* The semantics of negative stride in BLAS are that the vector + operand be traversed in reverse order. (Another way to think + of this is that negative strides effectively reverse the order + of the vector, but without any explicit data movements.) This + is also how BLIS interprets negative strides. The differences + is that with BLAS, the caller *always* passes in the 0th (i.e., + top-most or left-most) element of the vector, even when the + stride is negative. By contrast, in BLIS, negative strides are + used *relative* to the vector address as it is given. Thus, in + BLIS, if this backwards traversal is desired, the caller *must* + pass in the address to the (n-1)th (i.e., the bottom-most or + right-most) element along with a negative stride. */ + + x0 = (x) + (n0-1)*(-*incx); + incx0 = ( inc_t )(*incx); + + } + else + { + x0 = (x); + incx0 = ( inc_t )(*incx); + } + + if ( *incy < 0 ) + { + y0 = (y) + (n0-1)*(-*incy); + incy0 = ( inc_t )(*incy); + + } + else + { + y0 = (y); + incy0 = ( inc_t )(*incy); + } + + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + bli_dswapv_zen_int8 + ( + n0, + x0, incx0, + y0, incy0, + NULL + ); + } + else{ + PASTEMAC2(d,swapv,BLIS_TAPI_EX_SUF) \ + ( \ + n0, \ + x0, incx0, \ + y0, incy0, \ + NULL, \ + NULL \ + ); \ + } + + /* Finalize BLIS. */ +// bli_finalize_auto(); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) +} + +INSERT_GENTFUNC_BLAS_CZ( swap, swapv ) + + +#endif diff --git a/frame/compat/bla_trsm.c b/frame/compat/bla_trsm.c index 654d3530d2..fea7ba6f17 100644 --- a/frame/compat/bla_trsm.c +++ b/frame/compat/bla_trsm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -380,1167 +380,5 @@ void PASTEF77(ch,blasname) \ #endif #ifdef BLIS_ENABLE_BLAS -#ifdef BLIS_CONFIG_EPYC - -void strsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const float* alpha, - const float* a, const f77_int* lda, - float* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(s), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_FLOAT; - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_strsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (float*)alpha, - (float*)a, rs_a, cs_a, - (float*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_strsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (float*)alpha, - (float*)a, rs_a, cs_a, - (float*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - /* b = alpha * b; */ - bli_sscalv_ex - ( - conja, - m0, - (float*)alpha, - b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - float inva = 1.0/ *a; - for(dim_t indx = 0; indx < m0; indx ++) - { - b[indx] = ( inva * b[indx] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_strsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (float*)alpha, - (float*)a, cs_a, rs_a, - (float*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_strsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (float*)alpha, - (float*)a, cs_a, rs_a, - (float*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - /* b = alpha * b; */ - bli_sscalv_ex - ( - conja, - n0, - (float*)alpha, - b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - float inva = 1.0/ *a; - for(dim_t indx = 0; indx < n0; indx ++) - { - b[indx*cs_b] = (inva * b[indx*cs_b] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_strsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - } - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} - -void dtrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const double* alpha, - const double* a, const f77_int* lda, - double* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE ; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(d), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_DOUBLE; - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_dtrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_dtrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - /* b = alpha * b; */ - bli_dscalv_ex - ( - conja, - m0, - (double*)alpha, - b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - double inva = 1.0/ *a; - for(dim_t indx = 0; indx < m0; indx ++) - { - b[indx] = ( inva * b[indx] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_dtrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (double*)alpha, - (double*)a, cs_a, rs_a, - (double*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_dtrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (double*)alpha, - (double*)a, cs_a, rs_a, - (double*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - /* b = alpha * b; */ - bli_dscalv_ex - ( - conja, - n0, - (double*)alpha, - b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - double inva = 1.0/ *a; - for(dim_t indx = 0; indx < n0; indx ++) - { - b[indx*cs_b] = (inva * b[indx*cs_b] ); - } - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - if (bamdzen) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - } - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#if 0 -void ztrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const dcomplex* alpha, - const dcomplex* a, const f77_int* lda, - dcomplex* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(z), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_DCOMPLEX; - - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_ztrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - (dcomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_ztrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (dcomplex*)alpha, - (dcomplex*)a, rs_a, cs_a, - (dcomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - bli_zscalv_ex - ( - conja, - m0, - (dcomplex*)alpha, - (dcomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {1.0, 0.0}; - dcomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < m0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zscals(inva, b[indx]) -#else - - bli_zinvscals(inva, b[indx]) -#endif - } - - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ztrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (dcomplex*)alpha, - (dcomplex*)a, cs_a, rs_a, - (dcomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ztrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (dcomplex*)alpha, - (dcomplex*)a, cs_a, rs_a, - (dcomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - bli_zscalv_ex - ( - conja, - n0, - (dcomplex*)alpha, - (dcomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - dcomplex inva = {1.0, 0.0}; - dcomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < n0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_zscals(inva ,b[indx * cs_b]) -#else - - bli_zinvscals(inva ,b[indx * cs_b]) -#endif - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); - -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=500 && n0<=500) || - (nt && (m0+n0)<128) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#endif -#if 0 -void ctrsm_ -( - const f77_char* side, - const f77_char* uploa, - const f77_char* transa, - const f77_char* diaga, - const f77_int* m, - const f77_int* n, - const scomplex* alpha, - const scomplex* a, const f77_int* lda, - scomplex* b, const f77_int* ldb -) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', - *side, *uploa,*transa, *diaga, *m, *n, - (void*)alpha,*lda, *ldb); - - side_t blis_side; - uplo_t blis_uploa; - trans_t blis_transa; - diag_t blis_diaga; - dim_t m0, n0; - conj_t conja = BLIS_NO_CONJUGATE; - - /* Initialize BLIS. */ - bli_init_auto(); - - /* Perform BLAS parameter checking. */ - PASTEBLACHK(trsm) - ( - MKSTR(c), - MKSTR(trsm), - side, - uploa, - transa, - diaga, - m, - n, - lda, - ldb - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_side( *side, &blis_side ); - bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const num_t dt = BLIS_SCOMPLEX; - - - if( n0 == 1 ) - { - if( blis_side == BLIS_LEFT ) - { - if(bli_is_notrans(blis_transa)) - { - bli_ctrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - (scomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - bli_ctrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - m0, - (scomplex*)alpha, - (scomplex*)a, rs_a, cs_a, - (scomplex*)b, rs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) - { - bli_cscalv_ex - ( - conja, - m0, - (scomplex*)alpha, - (scomplex*)b, rs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - scomplex inva = {1.0, 0.0}; - scomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cinvscals(a_dup, inva); -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - - for(dim_t indx = 0; indx < m0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cscals(inva ,b[indx]) -#else - bli_cinvscals(inva, b[indx]) -#endif - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - - } - } - else if( m0 == 1 ) - { - if(blis_side == BLIS_RIGHT) - { - if(bli_is_notrans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ctrsv_unf_var1 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (scomplex*)alpha, - (scomplex*)a, cs_a, rs_a, - (scomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - else if(bli_is_trans(blis_transa)) - { - if(blis_uploa == BLIS_UPPER) - blis_uploa = BLIS_LOWER; - else - blis_uploa = BLIS_UPPER; - - bli_ctrsv_unf_var2 - ( - blis_uploa, - blis_transa, - blis_diaga, - n0, - (scomplex*)alpha, - (scomplex*)a, cs_a, rs_a, - (scomplex*)b, cs_b, - NULL - ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) - { - bli_cscalv_ex - ( - conja, - n0, - (scomplex*)alpha, - (scomplex*)b, cs_b, - NULL, - NULL - ); - if(blis_diaga == BLIS_NONUNIT_DIAG) - { - scomplex inva = {1.0, 0.0}; - scomplex a_dup; - /** - * For conjugate transpose and non-unit diagonal - * kernel, negating imaginary part of A. - * As the dimension of A is 1x1, there's going to - * be only one 1 element of A. - */ - if(*transa == 'C' && *diaga == 'N') - { - a_dup.real = a->real; - a_dup.imag = a->imag * -1.0; - } - else - { - a_dup.real = a->real; - a_dup.imag = a->imag; - } - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cinvscals(a_dup, inva) -#else - inva.real = a_dup.real; - inva.imag = a_dup.imag; -#endif - for(dim_t indx = 0; indx < n0; indx ++) - { -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - bli_cscals(inva ,b[indx * cs_b]) -#else - bli_cinvscals(inva, b[indx * cs_b]) -#endif - - } - } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return; - } - } - - const struc_t struca = BLIS_TRIANGULAR; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - - dim_t mn0_a; - - bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); - - bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); - - bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); - - bli_obj_set_uplo( blis_uploa, &ao ); - bli_obj_set_diag( blis_diaga, &ao ); - bli_obj_set_conjtrans( blis_transa, &ao ); - - bli_obj_set_struc( struca, &ao ); -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } -#endif - bli_trsmnat - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) - /* Finalize BLIS. */ - bli_finalize_auto(); -} -#endif -INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) -#else INSERT_GENTFUNC_BLAS( trsm, trsm ) #endif -#endif diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c new file mode 100644 index 0000000000..21b2a1598d --- /dev/null +++ b/frame/compat/bla_trsm_amd.c @@ -0,0 +1,1544 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + + +// +// Define BLAS-to-BLIS interfaces. +// + +#ifdef BLIS_BLAS3_CALLS_TAPI + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ +\ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ +\ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_side, \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#else + +#undef GENTFUNC +#define GENTFUNC( ftype, ch, blasname, blisname ) \ +\ +void PASTEF77(ch,blasname) \ + ( \ + const f77_char* side, \ + const f77_char* uploa, \ + const f77_char* transa, \ + const f77_char* diaga, \ + const f77_int* m, \ + const f77_int* n, \ + const ftype* alpha, \ + const ftype* a, const f77_int* lda, \ + ftype* b, const f77_int* ldb \ + ) \ +{ \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) \ + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *side, *uploa, \ + *transa, *diaga, *m, *n, (void*)alpha, *lda, *ldb); \ + side_t blis_side; \ + uplo_t blis_uploa; \ + trans_t blis_transa; \ + diag_t blis_diaga; \ + dim_t m0, n0; \ + ftype a_conj; \ + conj_t conja = BLIS_NO_CONJUGATE ; \ +\ + /* Initialize BLIS. */ \ + bli_init_auto(); \ +\ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + side, \ + uploa, \ + transa, \ + diaga, \ + m, \ + n, \ + lda, \ + ldb \ + ); \ +\ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); \ + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); \ +\ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ +\ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const num_t dt = PASTEMAC(ch,type); \ +\ + /* ----------------------------------------------------------- */ \ + /* TRSM API: AX = B, where X = B */ \ + /* CALL TRSV when X & B are vector and when A is Matrix */ \ + /* Case 1: LEFT : TRSM, B(mxn) = A(mxm) * X(mxn) */ \ + /* Case 2: RIGHT : TRSM, B(mxn) = X(mxn) * A(nxn) */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | A | X | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | LEFT | mxm | mxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mxm | mx1 | mx1 | TRSV | */ \ + /* | m = 1 | 1x1 | 1xn | 1xn | INVSCALS | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | | X | A | B | Implementation | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | RIGHT | mxn | nxn | mxn | | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* | n = 1 | mx1 | 1x1 | mx1 | Transpose and INVSCALS| */ \ + /* | m = 1 | 1xn | nxn | 1xn | Transpose and TRSV | */ \ + /* |--------|-------|-------|-------|------------------------| */ \ + /* If Transpose(A) uplo = lower then uplo = higher */ \ + /* If Transpose(A) uplo = higher then uplo = lower */ \ + /* ----------------------------------------------------------- */ \ +\ + if( n0 == 1 ) \ + { \ + if( blis_side == BLIS_LEFT ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch, trsv_unf_var2) \ + ( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, \ + NULL \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + else if(bli_is_trans(blis_transa)) \ + { \ + PASTEMAC(ch, trsv_unf_var1) \ + ( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, \ + NULL \ + ); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) \ + { \ + /* b = alpha * b; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + conja, \ + m0, \ + (ftype*)alpha, \ + b, rs_b, \ + NULL, \ + NULL \ + ); \ + if(blis_diaga == BLIS_NONUNIT_DIAG) \ + { \ + conja = bli_extract_conj( blis_transa ); \ + PASTEMAC(ch,copycjs)( conja, *a, a_conj ); \ + for(int indx = 0; indx < m0; indx ++) \ + { \ + PASTEMAC(ch,invscals)( a_conj, b[indx] ); \ + } \ + }\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if( m0 == 1 ) \ + { \ + if(blis_side == BLIS_RIGHT) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + if(blis_uploa == BLIS_UPPER) \ + blis_uploa = BLIS_LOWER; \ + else \ + blis_uploa = BLIS_UPPER; \ + PASTEMAC(ch, trsv_unf_var1)( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, cs_a, rs_a, \ + (ftype*)b, cs_b, \ + NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + else if(bli_is_trans(blis_transa)) \ + { \ + if(blis_uploa == BLIS_UPPER) \ + blis_uploa = BLIS_LOWER; \ + else \ + blis_uploa = BLIS_UPPER; \ + PASTEMAC(ch, trsv_unf_var2)( \ + blis_uploa, \ + blis_transa, \ + blis_diaga, \ + n0, \ + (ftype*)alpha, \ + (ftype*)a, cs_a, rs_a, \ + (ftype*)b, cs_b, \ + NULL); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) \ + { \ + /* b = alpha * b; */ \ + PASTEMAC2(ch,scalv,BLIS_TAPI_EX_SUF) \ + ( \ + conja, \ + n0, \ + (ftype*)alpha, \ + b, cs_b, \ + NULL, \ + NULL \ + ); \ + if(blis_diaga == BLIS_NONUNIT_DIAG) \ + { \ + conja = bli_extract_conj( blis_transa ); \ + PASTEMAC(ch,copycjs)( conja, *a, a_conj ); \ + for(int indx = 0; indx < n0; indx ++) \ + { \ + PASTEMAC(ch,invscals)( a_conj, b[indx*cs_b] ); \ + }\ + } \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + return; \ + } \ + } \ +\ + const struc_t struca = BLIS_TRIANGULAR; \ +\ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ +\ + dim_t mn0_a; \ +\ + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); \ +\ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ +\ + bli_obj_init_finish( dt, mn0_a, mn0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)b, rs_b, cs_b, &bo ); \ +\ + bli_obj_set_uplo( blis_uploa, &ao ); \ + bli_obj_set_diag( blis_diaga, &ao ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ +\ + bli_obj_set_struc( struca, &ao ); \ +\ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + blis_side, \ + &alphao, \ + &ao, \ + &bo, \ + NULL, \ + NULL \ + ); \ +\ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ +} + +#endif + +#ifdef BLIS_ENABLE_BLAS + +void strsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const float* alpha, + const float* a, const f77_int* lda, + float* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(s), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_FLOAT; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (float*)alpha, + (float*)a, rs_a, cs_a, + (float*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + m0, + (float*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(dim_t indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_strsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (float*)alpha, + (float*)a, cs_a, rs_a, + (float*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_sscalv_ex + ( + conja, + n0, + (float*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + float inva = 1.0/ *a; + for(dim_t indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (float*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (float*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (float*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_strsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + } + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} + +void dtrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const double* alpha, + const double* a, const f77_int* lda, + double* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE ; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(d), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DOUBLE; + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + m0, + (double*)alpha, + b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(dim_t indx = 0; indx < m0; indx ++) + { + b[indx] = ( inva * b[indx] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_dtrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (double*)alpha, + (double*)a, cs_a, rs_a, + (double*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + /* b = alpha * b; */ + bli_dscalv_ex + ( + conja, + n0, + (double*)alpha, + b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + double inva = 1.0/ *a; + for(dim_t indx = 0; indx < n0; indx ++) + { + b[indx*cs_b] = (inva * b[indx*cs_b] ); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (double*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (double*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_dtrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + } + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} +#if 0 +void ztrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const dcomplex* alpha, + const dcomplex* a, const f77_int* lda, + dcomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'z', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(z), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_DCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (dcomplex*)alpha, + (dcomplex*)a, rs_a, cs_a, + (dcomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + bli_zscalv_ex + ( + conja, + m0, + (dcomplex*)alpha, + (dcomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva, b[indx]) +#else + + bli_zinvscals(inva, b[indx]) +#endif + } + + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ztrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (dcomplex*)alpha, + (dcomplex*)a, cs_a, rs_a, + (dcomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + bli_zscalv_ex + ( + conja, + n0, + (dcomplex*)alpha, + (dcomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + dcomplex inva = {1.0, 0.0}; + dcomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_zscals(inva ,b[indx * cs_b]) +#else + + bli_zinvscals(inva ,b[indx * cs_b]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=500 && n0<=500) || + (nt && (m0+n0)<128) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} +#endif +#if 0 +void ctrsm_ +( + const f77_char* side, + const f77_char* uploa, + const f77_char* transa, + const f77_char* diaga, + const f77_int* m, + const f77_int* n, + const scomplex* alpha, + const scomplex* a, const f77_int* lda, + scomplex* b, const f77_int* ldb +) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', + *side, *uploa,*transa, *diaga, *m, *n, + (void*)alpha,*lda, *ldb); + + side_t blis_side; + uplo_t blis_uploa; + trans_t blis_transa; + diag_t blis_diaga; + dim_t m0, n0; + conj_t conja = BLIS_NO_CONJUGATE; + + /* Initialize BLIS. */ + bli_init_auto(); + + /* Perform BLAS parameter checking. */ + PASTEBLACHK(trsm) + ( + MKSTR(c), + MKSTR(trsm), + side, + uploa, + transa, + diaga, + m, + n, + lda, + ldb + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_side( *side, &blis_side ); + bli_param_map_netlib_to_blis_uplo( *uploa, &blis_uploa ); + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_diag( *diaga, &blis_diaga ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const num_t dt = BLIS_SCOMPLEX; + + + if( n0 == 1 ) + { + if( blis_side == BLIS_LEFT ) + { + if(bli_is_notrans(blis_transa)) + { + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + m0, + (scomplex*)alpha, + (scomplex*)a, rs_a, cs_a, + (scomplex*)b, rs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if( ( blis_side == BLIS_RIGHT ) && ( m0 != 1 ) ) + { + bli_cscalv_ex + ( + conja, + m0, + (scomplex*)alpha, + (scomplex*)b, rs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva); +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + + for(dim_t indx = 0; indx < m0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx]) +#else + bli_cinvscals(inva, b[indx]) +#endif + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + + } + } + else if( m0 == 1 ) + { + if(blis_side == BLIS_RIGHT) + { + if(bli_is_notrans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var1 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + else if(bli_is_trans(blis_transa)) + { + if(blis_uploa == BLIS_UPPER) + blis_uploa = BLIS_LOWER; + else + blis_uploa = BLIS_UPPER; + + bli_ctrsv_unf_var2 + ( + blis_uploa, + blis_transa, + blis_diaga, + n0, + (scomplex*)alpha, + (scomplex*)a, cs_a, rs_a, + (scomplex*)b, cs_b, + NULL + ); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + else if(( blis_side == BLIS_LEFT ) && ( n0 != 1 )) + { + bli_cscalv_ex + ( + conja, + n0, + (scomplex*)alpha, + (scomplex*)b, cs_b, + NULL, + NULL + ); + if(blis_diaga == BLIS_NONUNIT_DIAG) + { + scomplex inva = {1.0, 0.0}; + scomplex a_dup; + /** + * For conjugate transpose and non-unit diagonal + * kernel, negating imaginary part of A. + * As the dimension of A is 1x1, there's going to + * be only one 1 element of A. + */ + if(*transa == 'C' && *diaga == 'N') + { + a_dup.real = a->real; + a_dup.imag = a->imag * -1.0; + } + else + { + a_dup.real = a->real; + a_dup.imag = a->imag; + } + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cinvscals(a_dup, inva) +#else + inva.real = a_dup.real; + inva.imag = a_dup.imag; +#endif + for(dim_t indx = 0; indx < n0; indx ++) + { +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + bli_cscals(inva ,b[indx * cs_b]) +#else + bli_cinvscals(inva, b[indx * cs_b]) +#endif + + } + } + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return; + } + } + + const struc_t struca = BLIS_TRIANGULAR; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + + dim_t mn0_a; + + bli_set_dim_with_side( blis_side, m0, n0, &mn0_a ); + + bli_obj_init_finish_1x1( dt, (scomplex*)alpha, &alphao ); + + bli_obj_init_finish( dt, mn0_a, mn0_a, (scomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0, n0, (scomplex*)b, rs_b, cs_b, &bo ); + + bli_obj_set_uplo( blis_uploa, &ao ); + bli_obj_set_diag( blis_diaga, &ao ); + bli_obj_set_conjtrans( blis_transa, &ao ); + + bli_obj_set_struc( struca, &ao ); +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) + { + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif + bli_trsmnat + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO) + /* Finalize BLIS. */ + bli_finalize_auto(); +} +#endif +INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) + +#endif diff --git a/kernels/zen/1/bli_scalv_zen_int10.c b/kernels/zen/1/bli_scalv_zen_int10.c index de9d8339d3..7146e86879 100644 --- a/kernels/zen/1/bli_scalv_zen_int10.c +++ b/kernels/zen/1/bli_scalv_zen_int10.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017 - 2022, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -64,16 +64,7 @@ void bli_sscalv_zen_int10 if ( PASTEMAC(s,eq0)( *alpha ) ) { float* zero = bli_s0; -#ifdef BLIS_CONFIG_EPYC - bli_ssetv_zen_int - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); -#else + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); ssetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SETV_KER, cntx ); f ( @@ -83,7 +74,7 @@ void bli_sscalv_zen_int10 x, incx, cntx ); -#endif + return; } @@ -342,16 +333,7 @@ void bli_dscalv_zen_int10 if ( PASTEMAC(d,eq0)( *alpha ) ) { double* zero = bli_d0; -#ifdef BLIS_CONFIG_EPYC - bli_dsetv_zen_int - ( - BLIS_NO_CONJUGATE, - n, - zero, - x, incx, - cntx - ); -#else + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); dsetv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SETV_KER, cntx ); f @@ -362,7 +344,7 @@ void bli_dscalv_zen_int10 x, incx, cntx ); -#endif + return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_4.c b/kernels/zen/1f/bli_axpyf_zen_int_4.c index f5a043db84..bb24e6c52f 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_4.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_4.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -95,29 +95,6 @@ void bli_caxpyf_zen_int_4 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - scomplex* a1 = a + (0 )*inca + (i )*lda; - scomplex* chi1 = x + (i )*incx; - scomplex* y1 = y + (0 )*incy; - scomplex alpha_chi1; - - bli_ccopycjs( conjx, *chi1, alpha_chi1 ); - bli_cscals( *alpha, alpha_chi1 ); - - bli_caxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -141,7 +118,6 @@ void bli_caxpyf_zen_int_4 ); } -#endif return; } @@ -357,28 +333,6 @@ void bli_zaxpyf_zen_int_4 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } -#else zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -402,7 +356,6 @@ void bli_zaxpyf_zen_int_4 ); } -#endif return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index 1125197775..d09a85f57f 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -108,29 +108,6 @@ void bli_saxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - float* a1 = a + (0 )*inca + (i )*lda; - float* chi1 = x + (i )*incx; - float* y1 = y + (0 )*incy; - float alpha_chi1; - - bli_scopycjs( conjx, *chi1, alpha_chi1 ); - bli_sscals( *alpha, alpha_chi1 ); - - bli_saxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -154,7 +131,6 @@ void bli_saxpyf_zen_int_5 ); } -#endif return; } @@ -382,29 +358,6 @@ void bli_daxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -428,7 +381,6 @@ void bli_daxpyf_zen_int_5 ); } -#endif return; } @@ -655,29 +607,6 @@ static void bli_daxpyf_zen_int_16x2 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -701,7 +630,6 @@ static void bli_daxpyf_zen_int_16x2 ); } -#endif return; } @@ -966,43 +894,21 @@ void bli_daxpyf_zen_int_16x4 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - if(b_n & 2) - { - bli_daxpyf_zen_int_16x2( conja, - conjx, - m, 2, - alpha, a, inca, lda, - x, incx, - y, incy, - cntx - ); - b_n -= 2; - a += 2*lda; - x += 2 * incx; - } - for ( i = 0; i < b_n; ++i ) - { - double* a1 = a + (0 )*inca + (i )*lda; - double* chi1 = x + (i )*incx; - double* y1 = y + (0 )*incy; - double alpha_chi1; - - bli_dcopycjs( conjx, *chi1, alpha_chi1 ); - bli_dscals( *alpha, alpha_chi1 ); - - bli_daxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } + if (b_n & 2) + { + bli_daxpyf_zen_int_16x2( conja, + conjx, + m, 2, + alpha, a, inca, lda, + x, incx, + y, incy, + cntx + ); + b_n -= 2; + a += 2*lda; + x += 2 * incx; + } -#else daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1026,7 +932,6 @@ void bli_daxpyf_zen_int_16x4 ); } -#endif return; } @@ -1396,29 +1301,6 @@ void bli_caxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - scomplex* a1 = a + (0 )*inca + (i )*lda; - scomplex* chi1 = x + (i )*incx; - scomplex* y1 = y + (0 )*incy; - scomplex alpha_chi1; - - bli_ccopycjs( conjx, *chi1, alpha_chi1 ); - bli_cscals( *alpha, alpha_chi1 ); - - bli_caxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else caxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_SCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1442,7 +1324,6 @@ void bli_caxpyf_zen_int_5 ); } -#endif return; } @@ -1810,29 +1691,6 @@ void bli_zaxpyf_zen_int_5 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - dcomplex* a1 = a + (0 )*inca + (i )*lda; - dcomplex* chi1 = x + (i )*incx; - dcomplex* y1 = y + (0 )*incy; - dcomplex alpha_chi1; - - bli_zcopycjs( conjx, *chi1, alpha_chi1 ); - bli_zscals( *alpha, alpha_chi1 ); - - bli_zaxpyv_zen_int5 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } - -#else zaxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DCOMPLEX, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -1855,8 +1713,7 @@ void bli_zaxpyf_zen_int_5 cntx ); } - -#endif + return; } diff --git a/kernels/zen/1f/bli_axpyf_zen_int_6.c b/kernels/zen/1f/bli_axpyf_zen_int_6.c index 99b544db15..cf7dbd1732 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_6.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_6.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -97,28 +97,6 @@ void bli_saxpyf_zen_int_6 // operation as a loop over axpyv. if ( b_n != fuse_fac ) { -#ifdef BLIS_CONFIG_EPYC - for ( i = 0; i < b_n; ++i ) - { - float* a1 = a + (0 )*inca + (i )*lda; - float* chi1 = x + (i )*incx; - float* y1 = y + (0 )*incy; - float alpha_chi1; - - bli_scopycjs( conjx, *chi1, alpha_chi1 ); - bli_sscals( *alpha, alpha_chi1 ); - - bli_saxpyv_zen_int10 - ( - conja, - m, - &alpha_chi1, - a1, inca, - y1, incy, - cntx - ); - } -#else saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx ); for ( i = 0; i < b_n; ++i ) @@ -141,7 +119,7 @@ void bli_saxpyf_zen_int_6 cntx ); } -#endif + return; } diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index d9c4047ec4..4815d57d72 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2017-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2017-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -114,13 +114,9 @@ err_t bli_gemm_small AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; #else - // When dynamic dispatch is enabled i.e. library is built for 'amdzen' configuration. - // Invoke architecture specific kernels only if we are sure that we are running on zen, - // zen2 or zen3 otherwise fall back to reference kernels (via framework and context). - arch_t id = bli_arch_query_id(); - bool bamdzen = (id == BLIS_ARCH_ZEN3) || (id == BLIS_ARCH_ZEN2) || (id == BLIS_ARCH_ZEN); - - if (0 == bamdzen) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) { return BLIS_NOT_YET_IMPLEMENTED; } From b406e818464fe63d69fc7a9210190828f4d57b42 Mon Sep 17 00:00:00 2001 From: Saitharun Date: Wed, 19 Jan 2022 11:38:45 +0530 Subject: [PATCH 20/63] Enable wrapper code by default details: Changes Made for 4.0 branch to enable wrapper code by default and also removed ENABLE_API_WRAPPER macro. Change-Id: I5c9ede7ae959d811bc009073a266e66cbf07ef1a --- CMakeLists.txt | 7 +------ frame/util/bli_util_api_wrap.c | 4 +++- frame/util/bli_util_api_wrap.h | 4 +++- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index e2cb3818e6..ccf98af52d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -93,9 +93,8 @@ option(BLIS_ENABLE_ILP64 "ENABLE BLIS ILP64" OFF) option(ENABLE_INT_TYPE_SIZE " Internal BLIS integers ,used in native BLIS interfaces based on architecture dependent " ON) option(ENABLE_BLASTEST "Enable the blastest" OFF) option(ENABLE_TESTCPP_TESTING "Enabling testcpp" OFF) -option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" ON) +option (ENABLE_NO_UNDERSCORE_API "export APIs without underscore" OFF) option (ENABLE_UPPERCASE_API "export APIs with uppercase" OFF) -option (ENABLE_API_WRAPPER "Enable wrapper code" OFF) option (ENABLE_COMPLEX_RETURN_INTEL "Enable complex_return_intel" OFF) option (ENABLE_TRSM_PREINVERSION "Enable TRSM preinversion" ON) option (ENABLE_AOCL_DYNAMIC "Enable Dynamic Multi-threading" OFF) @@ -125,10 +124,6 @@ if(ENABLE_UPPERCASE_API) add_definitions(-DBLIS_ENABLE_UPPERCASE_API) endif() -if(ENABLE_API_WRAPPER) - add_definitions(-DBLIS_ENABLE_API_WRAPPER) -endif() - if(ENABLE_AOCL_DYNAMIC) set(AOCL_DYNAMIC TRUE) endif() diff --git a/frame/util/bli_util_api_wrap.c b/frame/util/bli_util_api_wrap.c index 128fba8b87..81300761fb 100644 --- a/frame/util/bli_util_api_wrap.c +++ b/frame/util/bli_util_api_wrap.c @@ -39,7 +39,8 @@ #include "bli_util_api_wrap.h" // wrapper functions to support additional symbols -#ifdef BLIS_ENABLE_API_WRAPPER +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API void CAXPY(const f77_int *n,const scomplex *ca,const scomplex *cx,const f77_int *incx,scomplex *cy,const f77_int *incy) { caxpy_( n, ca, cx, incx, cy, incy); @@ -3221,3 +3222,4 @@ void CAXPBY_( const f77_int* n, const scomplex* alpha, const scomplex *x, con } #endif +#endif diff --git a/frame/util/bli_util_api_wrap.h b/frame/util/bli_util_api_wrap.h index f0aff49ff2..78f088e28e 100644 --- a/frame/util/bli_util_api_wrap.h +++ b/frame/util/bli_util_api_wrap.h @@ -35,7 +35,8 @@ // file define different formats of BLAS APIs- uppercase with // and without underscore, lowercase without underscore. -#ifdef BLIS_ENABLE_API_WRAPPER +#ifndef BLIS_ENABLE_NO_UNDERSCORE_API +#ifndef BLIS_ENABLE_UPPERCASE_API //Level 1 APIs BLIS_EXPORT_BLIS void SROTG(float *sa, float *sb, float *c, float *s); @@ -1729,3 +1730,4 @@ BLIS_EXPORT_BLIS void ZOMATCOPY_(f77_char* trans, f77_int* rows, f77_int* cols #endif +#endif From 267f3092632946f026eff3513cfe314429afd563 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan Date: Fri, 7 Jan 2022 14:10:56 +0530 Subject: [PATCH 21/63] Improved performance of DOTXV kernel for float and double - Vectorized sections of code that were not vectorized AMD Internal: [CPUPL-1980] Change-Id: I08528d054442a5e728f631142f244f1624170136 --- kernels/zen/1/bli_dotxv_zen_int.c | 131 ++++++++++++++++++------------ 1 file changed, 78 insertions(+), 53 deletions(-) diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index 99ea517104..8ba1d1bba4 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2016 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. Copyright (C) 2018, The University of Texas at Austin Redistribution and use in source and binary forms, with or without @@ -36,6 +36,14 @@ #include "immintrin.h" #include "blis.h" +/* Union data structure to access AVX registers + One 128-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m128 v; + float f[4] __attribute__((aligned(64))); +} v4sf_t; + /* Union data structure to access AVX registers One 256-bit AVX register holds 8 SP elements. */ typedef union @@ -44,6 +52,14 @@ typedef union float f[8] __attribute__((aligned(64))); } v8sf_t; +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 4 DP elements. */ +typedef union +{ + __m128d v; + double d[2] __attribute__((aligned(64))); +} v2df_t; + /* Union data structure to access AVX registers * One 256-bit AVX register holds 4 DP elements. */ typedef union @@ -78,11 +94,7 @@ void bli_sdotxv_zen_int float* restrict y0; float rho0; - v8sf_t rho0v, rho1v, rho2v, rho3v; - v8sf_t x0v, y0v; - v8sf_t x1v, y1v; - v8sf_t x2v, y2v; - v8sf_t x3v, y3v; + v8sf_t rhov[4], xv[4], yv[4]; // If beta is zero, initialize rho1 to zero instead of scaling // rho by beta (in case rho contains NaN or Inf). @@ -117,45 +129,55 @@ void bli_sdotxv_zen_int y0 = y; // Initialize the unrolled iterations' rho vectors to zero. - rho0v.v = _mm256_setzero_ps(); - rho1v.v = _mm256_setzero_ps(); - rho2v.v = _mm256_setzero_ps(); - rho3v.v = _mm256_setzero_ps(); + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); for ( i = 0; i < n_viter; ++i ) { // Load the x and y input vector elements. - x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); - y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); + xv[0].v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg ); - x1v.v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); - y1v.v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); + xv[1].v = _mm256_loadu_ps( x0 + 1*n_elem_per_reg ); + yv[1].v = _mm256_loadu_ps( y0 + 1*n_elem_per_reg ); - x2v.v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); - y2v.v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); + xv[2].v = _mm256_loadu_ps( x0 + 2*n_elem_per_reg ); + yv[2].v = _mm256_loadu_ps( y0 + 2*n_elem_per_reg ); - x3v.v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); - y3v.v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); + xv[3].v = _mm256_loadu_ps( x0 + 3*n_elem_per_reg ); + yv[3].v = _mm256_loadu_ps( y0 + 3*n_elem_per_reg ); // Compute the element-wise product of the x and y vectors, // storing in the corresponding rho vectors. - rho0v.v = _mm256_fmadd_ps( x0v.v, y0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_ps( x1v.v, y1v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_ps( x2v.v, y2v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_ps( x3v.v, y3v.v, rho3v.v ); + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); x0 += ( n_elem_per_reg * n_iter_unroll ); y0 += ( n_elem_per_reg * n_iter_unroll ); } // Accumulate the unrolled rho vectors into a single vector. - rho0v.v += rho1v.v; - rho0v.v += rho2v.v; - rho0v.v += rho3v.v; + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[1].v); + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[2].v); + rhov[0].v = _mm256_add_ps(rhov[0].v,rhov[3].v); + + v4sf_t inter0, inter1; + + inter0.v = _mm256_extractf128_ps(rhov[0].v,0); + inter1.v = _mm256_extractf128_ps(rhov[0].v,1); + + inter0.v = _mm_add_ps(inter0.v, inter1.v); + + inter1.v = _mm_permute_ps(inter0.v, 14); + + inter0.v = _mm_add_ps(inter0.v,inter1.v); // Accumulate the final rho vector into a single scalar result. - rho0 = rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] + - rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7]; + rho0 = inter0.f[0] + inter0.f[1]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when @@ -206,12 +228,8 @@ void bli_ddotxv_zen_int double* restrict y0; double rho0; - v4df_t rho0v, rho1v, rho2v, rho3v; - v4df_t x0v, y0v; - v4df_t x1v, y1v; - v4df_t x2v, y2v; - v4df_t x3v, y3v; - + v4df_t rhov[4], xv[4], yv[4]; + // If beta is zero, initialize rho1 to zero instead of scaling // rho by beta (in case rho contains NaN or Inf). if ( PASTEMAC(d,eq0)( *beta ) ) @@ -245,44 +263,51 @@ void bli_ddotxv_zen_int y0 = y; // Initialize the unrolled iterations' rho vectors to zero. - rho0v.v = _mm256_setzero_pd(); - rho1v.v = _mm256_setzero_pd(); - rho2v.v = _mm256_setzero_pd(); - rho3v.v = _mm256_setzero_pd(); + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); for ( i = 0; i < n_viter; ++i ) { // Load the x and y input vector elements. - x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + xv[0].v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - x1v.v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + xv[1].v = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - x2v.v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); - y2v.v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + xv[2].v = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); - x3v.v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); - y3v.v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + xv[3].v = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); // Compute the element-wise product of the x and y vectors, // storing in the corresponding rho vectors. - rho0v.v = _mm256_fmadd_pd( x0v.v, y0v.v, rho0v.v ); - rho1v.v = _mm256_fmadd_pd( x1v.v, y1v.v, rho1v.v ); - rho2v.v = _mm256_fmadd_pd( x2v.v, y2v.v, rho2v.v ); - rho3v.v = _mm256_fmadd_pd( x3v.v, y3v.v, rho3v.v ); + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); x0 += ( n_elem_per_reg * n_iter_unroll ); y0 += ( n_elem_per_reg * n_iter_unroll ); } // Accumulate the unrolled rho vectors into a single vector. - rho0v.v += rho1v.v; - rho0v.v += rho2v.v; - rho0v.v += rho3v.v; + rhov[0].v = _mm256_add_pd(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[3].v,rhov[0].v); + + v2df_t inter1, inter2; + + inter1.v = _mm256_extractf128_pd(rhov[0].v,1); + inter2.v = _mm256_extractf128_pd(rhov[0].v,0); + + inter1.v = _mm_add_pd(inter1.v, inter2.v); // Accumulate the final rho vector into a single scalar result. - rho0 = rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3]; + rho0 = inter1.d[0] + inter1.d[1]; // Issue vzeroupper instruction to clear upper lanes of ymm registers. // This avoids a performance penalty caused by false dependencies when From 67e95ae9872b1368f5316e2330a3e906fa2fb2de Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 1 Feb 2022 10:22:58 +0530 Subject: [PATCH 22/63] Optimized CPU feature determination. We added new API to check if the CPU architecture has support for AVX instruction. This API was calling CPUID instruction every time it is invoked. However, since this information does not change at runtime, it is sufficient to determine it once and use the cached results for subsequent calls. This optimization is needed to improve performance for small size matrix vector operations. AMD-Internal: [CPUPL-2009] Change-Id: If6697e1da6dd6b7f28fbfed45215ea3fdd569c5f --- frame/base/bli_cpuid.c | 45 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index db698e9d0f..f4251a8c5c 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Copyright (C) 2019, Dave Love, University of Manchester Redistribution and use in source and binary forms, with or without @@ -459,13 +459,23 @@ bool bli_cpuid_is_bulldozer return TRUE; } -bool bli_cpuid_is_avx_supported( void ) +// Check (at runtime) if AVX is supported on the current platform, this is to +// ensure that AVX kernels are not used on legacy platforms which results in crash + +// The support for AVX is checked only once (when this API is called first time) +// On subsequent calls the cached value is returned. This is achieved using +// pthread_once mechanism since this information does not change once the library +// is loaded. +static bool is_avx_supported = FALSE; + + +// Determine if the CPU has support for AVX. +void bli_cpuid_check_avx_support( void ) { uint32_t family, model, features; // Call the CPUID instruction and parse its results into a family id, - // model id, and a feature bit field. The return value encodes the - // vendor. + // model id, and a feature bit field. bli_cpuid_query( &family, &model, &features ); // Check for expected CPU features. @@ -473,9 +483,32 @@ bool bli_cpuid_is_avx_supported( void ) FEATURE_FMA3 | FEATURE_AVX2; - if ( !bli_cpuid_has_features( features, expected ) ) return FALSE; + if ( !bli_cpuid_has_features( features, expected ) ) + { + is_avx_supported = FALSE; + } + else + { + is_avx_supported = TRUE; + } +} - return TRUE; +static bli_pthread_once_t once_check_avx_support = BLIS_PTHREAD_ONCE_INIT; + +// Ensure that actual support determincation happens only once +void bli_cpuid_check_avx_support_once( void ) +{ +#ifndef BLIS_CONFIGURETIME_CPUID + bli_pthread_once( &once_check_avx_support, bli_cpuid_check_avx_support ); +#endif +} + +// API to check if AVX is supported or not on the current platform. +bool bli_cpuid_is_avx_supported( void ) +{ + bli_cpuid_check_avx_support_once(); + + return is_avx_supported; } #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) From d62c6392a4e12c02bb75dc2c8b6669ad63939dd7 Mon Sep 17 00:00:00 2001 From: Harihara Sudhan S Date: Fri, 28 Jan 2022 11:44:38 +0530 Subject: [PATCH 23/63] Improved DGEMV performance for column-major cases - Altered the framework to use 2 more fused kernels for better problem decomposition - Increased unroll factor in AXPYF5 and AXPYF8 kernels to improve register usage AMD-Internal: [CPUPL-1970] Change-Id: I79750235d9554466def5ff93898f832834990343 --- frame/2/gemv/bli_gemv_unf_var2_amd.c | 94 +++++- kernels/zen/1f/bli_axpyf_zen_int_5.c | 356 +++++++++++++-------- kernels/zen/1f/bli_axpyf_zen_int_8.c | 450 ++++++++++++++++++++------- 3 files changed, 653 insertions(+), 247 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var2_amd.c b/frame/2/gemv/bli_gemv_unf_var2_amd.c index d7f5145e31..831d906ca4 100644 --- a/frame/2/gemv/bli_gemv_unf_var2_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var2_amd.c @@ -313,27 +313,87 @@ void bli_dgemv_unf_var2 } } - for ( i = 0; i < n_iter; i += f ) + dim_t fuse_factor = 8; + dim_t f_temp = 0; + + // Change the fuse factor based on + // Input size and available kernels + // This ensures that fusing is possible when the number of + // left over colums is less (better problem decomposition) + if (n < 5) fuse_factor = 4; + else if (n < 8) fuse_factor = 5; + + for (i = 0; i < n_iter; i += f) { - f = bli_determine_blocksize_dim_f( i, n_iter, BLIS_DGEMV_VAR2_FUSE ); + f = bli_determine_blocksize_dim_f(i, n_iter, fuse_factor); - A1 = a + (0 )*rs_at + (i )*cs_at; - x1 = x + (i )*incx; + A1 = a + (i)*cs_at; + x1 = x + (i)*incx; - /* y = y + alpha * A1 * x1; */ - bli_daxpyf_zen_int_16x4 - ( - conja, - conjx, - n_elem, - f, - alpha, - A1, rs_at, cs_at, - x1, incx, - y_buf, buf_incy, - cntx - ); + // Pick kernel based on problem size + switch (f) + { + case 8: + + bli_daxpyf_zen_int_8( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + + break; + default: + + if (f < 5) + { + bli_daxpyf_zen_int_16x4( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + else + { + bli_daxpyf_zen_int_5( + conja, + conjx, + n_elem, + f, + alpha, + A1, rs_at, cs_at, + x1, incx, + y_buf, buf_incy, + cntx); + } + } + + // Calculate the next problem size + f_temp = bli_determine_blocksize_dim_f(i + f, n_iter, fuse_factor); + + // Change fuse factor based on the next problem size + if (f_temp < fuse_factor) + { + if (f_temp < 5) + { + fuse_factor = 4; + } + else + { + fuse_factor = 5; + } + } } + if ((incy > 1) && bli_mem_is_alloc( &mem_bufY )) { //store the result from unit strided y_buf to non-unit strided Y diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index d09a85f57f..8b1f697cec 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -329,27 +329,13 @@ void bli_daxpyf_zen_int_5 dim_t i; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; + double* restrict av[5] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v; - - v4df_t a00v, a01v, a02v, a03v; - v4df_t a04v; - - v4df_t a10v, a11v, a12v, a13v; - v4df_t a14v; - - v4df_t y0v, y1v; - - double chi0, chi1, chi2, chi3; - double chi4; + v4df_t chiv[5], a_vec[20], yv[4]; + + double chi[5]; // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || bli_deq0( *alpha ) ) return; @@ -385,117 +371,241 @@ void bli_daxpyf_zen_int_5 } // At this point, we know that b_n is exactly equal to the fusing factor. - - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; + // av points to the 5 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); // Scale each chi scalar by alpha. - bli_dscals( *alpha, chi0 ); - bli_dscals( *alpha, chi1 ); - bli_dscals( *alpha, chi2 ); - bli_dscals( *alpha, chi3 ); - bli_dscals( *alpha, chi4 ); + bli_dscals( *alpha, chi[0] ); + bli_dscals( *alpha, chi[1] ); + bli_dscals( *alpha, chi[2] ); + bli_dscals( *alpha, chi[3] ); + bli_dscals( *alpha, chi[4] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); // If there are vectorized iterations, perform them with vector // instructions. if ( inca == 1 && incy == 1 ) { - for ( i = 0; (i + 7) < m; i += 8 ) + // 16 elements of the result are computed per iteration + for ( i = 0; (i + 15) < m; i += 16 ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - y1v.v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + + a_vec[15].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[16].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a10v.v = _mm256_loadu_pd( a0 + 1*n_elem_per_reg ); - - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a11v.v = _mm256_loadu_pd( a1 + 1*n_elem_per_reg ); - - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a12v.v = _mm256_loadu_pd( a2 + 1*n_elem_per_reg ); + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[15].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[16].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[17].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[18].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[19].v, chiv[4].v, yv[3].v ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a13v.v = _mm256_loadu_pd( a3 + 1*n_elem_per_reg ); + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * 4; + av[0] += n_elem_per_reg * 4; + av[1] += n_elem_per_reg * 4; + av[2] += n_elem_per_reg * 4; + av[3] += n_elem_per_reg * 4; + av[4] += n_elem_per_reg * 4; + } - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a14v.v = _mm256_loadu_pd( a4 + 1*n_elem_per_reg ); + // 12 elements of the result are computed per iteration + for ( ; (i + 11) < m; i += 12 ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + + a_vec[10].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a10v.v, chi0v.v, y1v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[10].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[11].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[12].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[13].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[14].v, chiv[4].v, yv[2].v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a11v.v, chi1v.v, y1v.v ); + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * 3; + av[0] += n_elem_per_reg * 3; + av[1] += n_elem_per_reg * 3; + av[2] += n_elem_per_reg * 3; + av[3] += n_elem_per_reg * 3; + av[4] += n_elem_per_reg * 3; + } - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a12v.v, chi2v.v, y1v.v ); + // 8 elements of the result are computed per iteration + for (; (i + 7) < m; i += 8 ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a13v.v, chi3v.v, y1v.v ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); - y1v.v = _mm256_fmadd_pd( a14v.v, chi4v.v, y1v.v ); + a_vec[5].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[8].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[5].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[6].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[7].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[4].v, yv[1].v ); // Store the output. - _mm256_storeu_pd( (double *)(y0 + 0*n_elem_per_reg), y0v.v ); - _mm256_storeu_pd( (double *)(y0 + 1*n_elem_per_reg), y1v.v ); - - y0 += n_iter_unroll * n_elem_per_reg; - a0 += n_iter_unroll * n_elem_per_reg; - a1 += n_iter_unroll * n_elem_per_reg; - a2 += n_iter_unroll * n_elem_per_reg; - a3 += n_iter_unroll * n_elem_per_reg; - a4 += n_iter_unroll * n_elem_per_reg; + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * 2; + av[0] += n_elem_per_reg * 2; + av[1] += n_elem_per_reg * 2; + av[2] += n_elem_per_reg * 2; + av[3] += n_elem_per_reg * 2; + av[4] += n_elem_per_reg * 2; } + // 4 elements of the result are computed per iteration for( ; (i + 3) < m; i += 4 ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - - a00v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a01v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a02v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a03v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a04v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a00v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a01v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a02v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a03v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a04v.v, chi4v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); // Store the output. - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -503,25 +613,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += 1; - a1 += 1; - a2 += 1; - a3 += 1; - a4 += 1; + av[0] += 1; + av[1] += 1; + av[2] += 1; + av[3] += 1; + av[4] += 1; y0 += 1; } } @@ -531,25 +641,25 @@ void bli_daxpyf_zen_int_5 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; y0 += incy; } @@ -1153,7 +1263,7 @@ void bli_daxpyf_zen_int_16x4 a2 += n_elem_per_reg; a3 += n_elem_per_reg; } -#if 1 + for ( ; (i + 1) < m; i += 2) { @@ -1186,7 +1296,7 @@ void bli_daxpyf_zen_int_16x4 a2 += 2; a3 += 2; } -#endif + // If there are leftover iterations, perform them with scalar code. for ( ; (i + 0) < m ; ++i ) { diff --git a/kernels/zen/1f/bli_axpyf_zen_int_8.c b/kernels/zen/1f/bli_axpyf_zen_int_8.c index b958600ce6..27dafb28fc 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_8.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2018, The University of Texas at Austin - Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc. + Copyright (C) 2016 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -279,32 +279,19 @@ void bli_daxpyf_zen_int_8 const dim_t fuse_fac = 8; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 1; + const dim_t n_iter_unroll[4] = {4, 3, 2, 1}; dim_t i; - dim_t m_viter; - dim_t m_left; + dim_t m_viter[4]; + dim_t m_left = m; - double* restrict a0; - double* restrict a1; - double* restrict a2; - double* restrict a3; - double* restrict a4; - double* restrict a5; - double* restrict a6; - double* restrict a7; + double* restrict av[8] __attribute__((aligned(64))); double* restrict y0; - v4df_t chi0v, chi1v, chi2v, chi3v; - v4df_t chi4v, chi5v, chi6v, chi7v; + v4df_t chiv[8], a_vec[32], yv[4]; - v4df_t a0v, a1v, a2v, a3v; - v4df_t a4v, a5v, a6v, a7v; - v4df_t y0v; - - double chi0, chi1, chi2, chi3; - double chi4, chi5, chi6, chi7; + double chi[8] __attribute__((aligned(64))); // If either dimension is zero, or if alpha is zero, return early. if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return; @@ -343,94 +330,343 @@ void bli_daxpyf_zen_int_8 // Use the unrolling factor and the number of elements per register // to compute the number of vectorized and leftover iterations. - m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll ); - m_left = ( m ) % ( n_elem_per_reg * n_iter_unroll ); + m_viter[0] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[0] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[0] ); + + m_viter[1] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[1] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[1] ); + + m_viter[2] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[2] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[2] ); + + m_viter[3] = ( m_left ) / ( n_elem_per_reg * n_iter_unroll[3] ); + m_left = ( m_left ) % ( n_elem_per_reg * n_iter_unroll[3] ); // If there is anything that would interfere with our use of contiguous // vector loads/stores, override m_viter and m_left to use scalar code // for all iterations. if ( inca != 1 || incy != 1 ) { - m_viter = 0; + m_viter[0] = m_viter[1] = m_viter[2] = m_viter[3] = 0; m_left = m; } - a0 = a + 0*lda; - a1 = a + 1*lda; - a2 = a + 2*lda; - a3 = a + 3*lda; - a4 = a + 4*lda; - a5 = a + 5*lda; - a6 = a + 6*lda; - a7 = a + 7*lda; + // av points to the 8 columns under consideration + av[0] = a + 0*lda; + av[1] = a + 1*lda; + av[2] = a + 2*lda; + av[3] = a + 3*lda; + av[4] = a + 4*lda; + av[5] = a + 5*lda; + av[6] = a + 6*lda; + av[7] = a + 7*lda; y0 = y; - chi0 = *( x + 0*incx ); - chi1 = *( x + 1*incx ); - chi2 = *( x + 2*incx ); - chi3 = *( x + 3*incx ); - chi4 = *( x + 4*incx ); - chi5 = *( x + 5*incx ); - chi6 = *( x + 6*incx ); - chi7 = *( x + 7*incx ); + chi[0] = *( x + 0*incx ); + chi[1] = *( x + 1*incx ); + chi[2] = *( x + 2*incx ); + chi[3] = *( x + 3*incx ); + chi[4] = *( x + 4*incx ); + chi[5] = *( x + 5*incx ); + chi[6] = *( x + 6*incx ); + chi[7] = *( x + 7*incx ); // Scale each chi scalar by alpha. - PASTEMAC(d,scals)( *alpha, chi0 ); - PASTEMAC(d,scals)( *alpha, chi1 ); - PASTEMAC(d,scals)( *alpha, chi2 ); - PASTEMAC(d,scals)( *alpha, chi3 ); - PASTEMAC(d,scals)( *alpha, chi4 ); - PASTEMAC(d,scals)( *alpha, chi5 ); - PASTEMAC(d,scals)( *alpha, chi6 ); - PASTEMAC(d,scals)( *alpha, chi7 ); + PASTEMAC(d,scals)( *alpha, chi[0] ); + PASTEMAC(d,scals)( *alpha, chi[1] ); + PASTEMAC(d,scals)( *alpha, chi[2] ); + PASTEMAC(d,scals)( *alpha, chi[3] ); + PASTEMAC(d,scals)( *alpha, chi[4] ); + PASTEMAC(d,scals)( *alpha, chi[5] ); + PASTEMAC(d,scals)( *alpha, chi[6] ); + PASTEMAC(d,scals)( *alpha, chi[7] ); // Broadcast the (alpha*chi?) scalars to all elements of vector registers. - chi0v.v = _mm256_broadcast_sd( &chi0 ); - chi1v.v = _mm256_broadcast_sd( &chi1 ); - chi2v.v = _mm256_broadcast_sd( &chi2 ); - chi3v.v = _mm256_broadcast_sd( &chi3 ); - chi4v.v = _mm256_broadcast_sd( &chi4 ); - chi5v.v = _mm256_broadcast_sd( &chi5 ); - chi6v.v = _mm256_broadcast_sd( &chi6 ); - chi7v.v = _mm256_broadcast_sd( &chi7 ); + chiv[0].v = _mm256_broadcast_sd( &chi[0] ); + chiv[1].v = _mm256_broadcast_sd( &chi[1] ); + chiv[2].v = _mm256_broadcast_sd( &chi[2] ); + chiv[3].v = _mm256_broadcast_sd( &chi[3] ); + chiv[4].v = _mm256_broadcast_sd( &chi[4] ); + chiv[5].v = _mm256_broadcast_sd( &chi[5] ); + chiv[6].v = _mm256_broadcast_sd( &chi[6] ); + chiv[7].v = _mm256_broadcast_sd( &chi[7] ); // If there are vectorized iterations, perform them with vector // instructions. - for ( i = 0; i < m_viter; ++i ) + // 16 elements of the result are computed per iteration + for ( i = 0; i < m_viter[0]; ++i ) { // Load the input values. - y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); - a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg ); - a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg ); - a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg ); - a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg ); - a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg ); - a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg ); - a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg ); - a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg ); + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3].v = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + a_vec[24].v = _mm256_loadu_pd( av[0] + 3*n_elem_per_reg ); + a_vec[25].v = _mm256_loadu_pd( av[1] + 3*n_elem_per_reg ); + a_vec[26].v = _mm256_loadu_pd( av[2] + 3*n_elem_per_reg ); + a_vec[27].v = _mm256_loadu_pd( av[3] + 3*n_elem_per_reg ); + a_vec[28].v = _mm256_loadu_pd( av[4] + 3*n_elem_per_reg ); + a_vec[29].v = _mm256_loadu_pd( av[5] + 3*n_elem_per_reg ); + a_vec[30].v = _mm256_loadu_pd( av[6] + 3*n_elem_per_reg ); + a_vec[31].v = _mm256_loadu_pd( av[7] + 3*n_elem_per_reg ); // perform : y += alpha * x; - y0v.v = _mm256_fmadd_pd( a0v.v, chi0v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a1v.v, chi1v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a2v.v, chi2v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a3v.v, chi3v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a4v.v, chi4v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a5v.v, chi5v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a6v.v, chi6v.v, y0v.v ); - y0v.v = _mm256_fmadd_pd( a7v.v, chi7v.v, y0v.v ); + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + yv[3].v = _mm256_fmadd_pd( a_vec[24].v, chiv[0].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[25].v, chiv[1].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[26].v, chiv[2].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[27].v, chiv[3].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[28].v, chiv[4].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[29].v, chiv[5].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[30].v, chiv[6].v, yv[3].v ); + yv[3].v = _mm256_fmadd_pd( a_vec[31].v, chiv[7].v, yv[3].v ); // Store the output. - _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v ); + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + _mm256_storeu_pd( (y0 + 3*n_elem_per_reg), yv[3].v ); + + y0 += n_elem_per_reg * n_iter_unroll[0]; + av[0] += n_elem_per_reg * n_iter_unroll[0]; + av[1] += n_elem_per_reg * n_iter_unroll[0]; + av[2] += n_elem_per_reg * n_iter_unroll[0]; + av[3] += n_elem_per_reg * n_iter_unroll[0]; + av[4] += n_elem_per_reg * n_iter_unroll[0]; + av[5] += n_elem_per_reg * n_iter_unroll[0]; + av[6] += n_elem_per_reg * n_iter_unroll[0]; + av[7] += n_elem_per_reg * n_iter_unroll[0]; + } + + // 12 elements of the result are computed per iteration + for ( i = 0; i < m_viter[1]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2].v = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + a_vec[16].v = _mm256_loadu_pd( av[0] + 2*n_elem_per_reg ); + a_vec[17].v = _mm256_loadu_pd( av[1] + 2*n_elem_per_reg ); + a_vec[18].v = _mm256_loadu_pd( av[2] + 2*n_elem_per_reg ); + a_vec[19].v = _mm256_loadu_pd( av[3] + 2*n_elem_per_reg ); + a_vec[20].v = _mm256_loadu_pd( av[4] + 2*n_elem_per_reg ); + a_vec[21].v = _mm256_loadu_pd( av[5] + 2*n_elem_per_reg ); + a_vec[22].v = _mm256_loadu_pd( av[6] + 2*n_elem_per_reg ); + a_vec[23].v = _mm256_loadu_pd( av[7] + 2*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + yv[2].v = _mm256_fmadd_pd( a_vec[16].v, chiv[0].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[17].v, chiv[1].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[18].v, chiv[2].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[19].v, chiv[3].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[20].v, chiv[4].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[21].v, chiv[5].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[22].v, chiv[6].v, yv[2].v ); + yv[2].v = _mm256_fmadd_pd( a_vec[23].v, chiv[7].v, yv[2].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + _mm256_storeu_pd( (y0 + 2*n_elem_per_reg), yv[2].v ); + + y0 += n_elem_per_reg * n_iter_unroll[1]; + av[0] += n_elem_per_reg * n_iter_unroll[1]; + av[1] += n_elem_per_reg * n_iter_unroll[1]; + av[2] += n_elem_per_reg * n_iter_unroll[1]; + av[3] += n_elem_per_reg * n_iter_unroll[1]; + av[4] += n_elem_per_reg * n_iter_unroll[1]; + av[5] += n_elem_per_reg * n_iter_unroll[1]; + av[6] += n_elem_per_reg * n_iter_unroll[1]; + av[7] += n_elem_per_reg * n_iter_unroll[1]; + } + + // 8 elements of the result are computed per iteration + for ( i = 0; i < m_viter[2]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1].v = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + a_vec[8].v = _mm256_loadu_pd( av[0] + 1*n_elem_per_reg ); + a_vec[9].v = _mm256_loadu_pd( av[1] + 1*n_elem_per_reg ); + a_vec[10].v = _mm256_loadu_pd( av[2] + 1*n_elem_per_reg ); + a_vec[11].v = _mm256_loadu_pd( av[3] + 1*n_elem_per_reg ); + a_vec[12].v = _mm256_loadu_pd( av[4] + 1*n_elem_per_reg ); + a_vec[13].v = _mm256_loadu_pd( av[5] + 1*n_elem_per_reg ); + a_vec[14].v = _mm256_loadu_pd( av[6] + 1*n_elem_per_reg ); + a_vec[15].v = _mm256_loadu_pd( av[7] + 1*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + yv[1].v = _mm256_fmadd_pd( a_vec[8].v, chiv[0].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[9].v, chiv[1].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[10].v, chiv[2].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[11].v, chiv[3].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[12].v, chiv[4].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[13].v, chiv[5].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[14].v, chiv[6].v, yv[1].v ); + yv[1].v = _mm256_fmadd_pd( a_vec[15].v, chiv[7].v, yv[1].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); + _mm256_storeu_pd( (y0 + 1*n_elem_per_reg), yv[1].v ); + + y0 += n_elem_per_reg * n_iter_unroll[2]; + av[0] += n_elem_per_reg * n_iter_unroll[2]; + av[1] += n_elem_per_reg * n_iter_unroll[2]; + av[2] += n_elem_per_reg * n_iter_unroll[2]; + av[3] += n_elem_per_reg * n_iter_unroll[2]; + av[4] += n_elem_per_reg * n_iter_unroll[2]; + av[5] += n_elem_per_reg * n_iter_unroll[2]; + av[6] += n_elem_per_reg * n_iter_unroll[2]; + av[7] += n_elem_per_reg * n_iter_unroll[2]; + } + + // 4 elements of the result are computed per iteration + for ( i = 0; i < m_viter[3]; ++i ) + { + // Load the input values. + yv[0].v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + a_vec[0].v = _mm256_loadu_pd( av[0] + 0*n_elem_per_reg ); + a_vec[1].v = _mm256_loadu_pd( av[1] + 0*n_elem_per_reg ); + a_vec[2].v = _mm256_loadu_pd( av[2] + 0*n_elem_per_reg ); + a_vec[3].v = _mm256_loadu_pd( av[3] + 0*n_elem_per_reg ); + a_vec[4].v = _mm256_loadu_pd( av[4] + 0*n_elem_per_reg ); + a_vec[5].v = _mm256_loadu_pd( av[5] + 0*n_elem_per_reg ); + a_vec[6].v = _mm256_loadu_pd( av[6] + 0*n_elem_per_reg ); + a_vec[7].v = _mm256_loadu_pd( av[7] + 0*n_elem_per_reg ); + + // perform : y += alpha * x; + yv[0].v = _mm256_fmadd_pd( a_vec[0].v, chiv[0].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[1].v, chiv[1].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[2].v, chiv[2].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[3].v, chiv[3].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[4].v, chiv[4].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[5].v, chiv[5].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[6].v, chiv[6].v, yv[0].v ); + yv[0].v = _mm256_fmadd_pd( a_vec[7].v, chiv[7].v, yv[0].v ); + + // Store the output. + _mm256_storeu_pd( (y0 + 0*n_elem_per_reg), yv[0].v ); y0 += n_elem_per_reg; - a0 += n_elem_per_reg; - a1 += n_elem_per_reg; - a2 += n_elem_per_reg; - a3 += n_elem_per_reg; - a4 += n_elem_per_reg; - a5 += n_elem_per_reg; - a6 += n_elem_per_reg; - a7 += n_elem_per_reg; + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + av[4] += n_elem_per_reg; + av[5] += n_elem_per_reg; + av[6] += n_elem_per_reg; + av[7] += n_elem_per_reg; } // If there are leftover iterations, perform them with scalar code. @@ -438,34 +674,34 @@ void bli_daxpyf_zen_int_8 { double y0c = *y0; - const double a0c = *a0; - const double a1c = *a1; - const double a2c = *a2; - const double a3c = *a3; - const double a4c = *a4; - const double a5c = *a5; - const double a6c = *a6; - const double a7c = *a7; - - y0c += chi0 * a0c; - y0c += chi1 * a1c; - y0c += chi2 * a2c; - y0c += chi3 * a3c; - y0c += chi4 * a4c; - y0c += chi5 * a5c; - y0c += chi6 * a6c; - y0c += chi7 * a7c; + const double a0c = *av[0]; + const double a1c = *av[1]; + const double a2c = *av[2]; + const double a3c = *av[3]; + const double a4c = *av[4]; + const double a5c = *av[5]; + const double a6c = *av[6]; + const double a7c = *av[7]; + + y0c += chi[0] * a0c; + y0c += chi[1] * a1c; + y0c += chi[2] * a2c; + y0c += chi[3] * a3c; + y0c += chi[4] * a4c; + y0c += chi[5] * a5c; + y0c += chi[6] * a6c; + y0c += chi[7] * a7c; *y0 = y0c; - a0 += inca; - a1 += inca; - a2 += inca; - a3 += inca; - a4 += inca; - a5 += inca; - a6 += inca; - a7 += inca; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + av[4] += inca; + av[5] += inca; + av[6] += inca; + av[7] += inca; y0 += incy; } } From d95629dcfc1227e3b881ace7932e07254955c9ad Mon Sep 17 00:00:00 2001 From: "Dipal M. Zambare" Date: Fri, 11 Feb 2022 12:12:01 +0530 Subject: [PATCH 24/63] Updated version number to 3.2 Change-Id: Iea5712d8cb854d4eaffea510e0fe2d9657e4d21f --- so_version | 2 +- version | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/so_version b/so_version index b1f189286c..8efd5969fe 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.1 +2.0 diff --git a/version b/version index 1795fa298a..252fb77212 100644 --- a/version +++ b/version @@ -1,2 +1,2 @@ -3.1.1 +3.2.0 From c13d981f03a3aee62fe629acb53797cafc153038 Mon Sep 17 00:00:00 2001 From: Meghana Vankadari Date: Wed, 2 Feb 2022 15:28:09 +0530 Subject: [PATCH 25/63] Fixed a bug in deriving dimensions from objects in gemm_front files Change-Id: I1f796c3a7ce6efacb6ef64651a7818b7ee38c6bb --- frame/3/gemm/bli_gemm_front.c | 10 +++------- frame/3/gemm/bli_gemm_front_amd.c | 2 +- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index d19d2eaea3..46e163c026 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -174,10 +174,6 @@ void bli_gemm_front bli_obj_swap_pack_schemas( &a_local, &b_local ); } - dim_t m_dim_local = bli_obj_length( &c_local ); - dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width( &a_local ); - // Parse and interpret the contents of the rntm_t object to properly // set the ways of parallelism for each loop, and then make any // additional modifications necessary for the current operation. @@ -185,9 +181,9 @@ void bli_gemm_front ( BLIS_GEMM, BLIS_LEFT, // ignored for gemm/hemm/symm - m_dim_local, - n_dim_local, - k_dim_local, + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width_after_trans( &a_local ), rntm ); diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index 41af62007c..a29a0bb85b 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -176,7 +176,7 @@ void bli_gemm_front dim_t m_dim_local = bli_obj_length( &c_local ); dim_t n_dim_local = bli_obj_width( &c_local ); - dim_t k_dim_local = bli_obj_width( &a_local ); + dim_t k_dim_local = bli_obj_width_after_trans( &a_local ); // Regression observed in sgemm native path in cases where m >= 4 * n // after BLIS_THREAD_RATIO_M updated from 2 to 1 as part of commit From 2aa710e36b8ae97dff33a9370ae75c68efccd51b Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Mon, 14 Feb 2022 17:43:41 +0530 Subject: [PATCH 26/63] AOCL_Windows: Updated windows build system. Updated the windows build system to link the user given openmp library using -DOpenMP_libomp_LIBRARY= option using command line or through cmake-gui application to build blis library and its test applications. If user not given any openmp library then by default openmp library will be C:/Program Files/LLVM/lib/libomp.lib. Change-Id: I07542c79454496f88e65e26327ad76a7f49c7a8c --- CMakeLists.txt | 13 ++++-- test/CMakeLists.txt | 98 ++++++++++++++++++++-------------------- testsuite/CMakeLists.txt | 6 +-- 3 files changed, 60 insertions(+), 57 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index ccf98af52d..8ba483e36b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,8 @@ set(CMAKE_RUNTIME_OUTPUT_DIRECTORY "${CMAKE_SOURCE_DIR}/bin") SET(AOCL_BLIS_FAMILY "zen" CACHE STRING "AOCL BLIS family name") -SET(OMP_LIB "C:\\Program Files\\LLVM\\lib\\libomp.lib" CACHE STRING "openmp library path") +SET(OpenMP_libomp_LIBRARY "C:/Program Files/LLVM/lib/libomp.lib" CACHE STRING "openmp library +path") set(TARGET_ARCH ${AOCL_BLIS_FAMILY}) set(AOCL_BLIS_ZEN TRUE) set (PYTHON_EXE "python") @@ -258,6 +259,9 @@ if(ENABLE_MULTITHREADING) find_package(OpenMP) if (OPENMP_FOUND) set(BLIS_ENABLE_OPENMP TRUE) + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} ${OpenMP_EXE_LINKER_FLAGS}") else() message (FATAL_ERROR "Openmp Not Found") endif() @@ -526,14 +530,12 @@ file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) add_definitions(-DBLIS_VERSION_STRING="AOCL BLIS ${BLIS_VERSION_STRING}") -message( STATUS "OPENMP Library:" ${OMP_LIB}) - if(BUILD_SHARED_LIBS) add_library("${PROJECT_NAME}" SHARED ${CMAKE_SOURCE_DIR}/bli_config.h ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - target_link_libraries("${PROJECT_NAME}" PUBLIC "${OMP_LIB}") + target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() target_compile_definitions("${PROJECT_NAME}" PUBLIC -DBLIS_IS_BUILDING_LIBRARY) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") @@ -543,9 +545,10 @@ if(NOT BUILD_SHARED_LIBS) ${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/blis.h ${headers}) if(ENABLE_OPENMP) - set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OMP_LIB}") + set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") else() set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") + target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() endif() diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index fe8f7bac98..d116e942d0 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,172 +1,172 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## add_definitions(-DBLAS="AOCL") add_executable(TestAminv test_aminv.c) target_link_libraries(TestAminv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestAminv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestAminv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAminv optimized "${LIB_NAME}.lib") add_executable(TestAxpyv test_axpyv.c) target_link_libraries(TestAxpyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestAxpyv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestAxpyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAxpyv optimized "${LIB_NAME}.lib") add_executable(TestAxpbyv test_axpbyv.c) target_link_libraries(TestAxpbyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestAxpbyv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestAxpbyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestAxpbyv optimized "${LIB_NAME}.lib") add_executable(TestCopyv test_copyv.c) target_link_libraries(TestCopyv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestCopyv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestCopyv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestCopyv optimized "${LIB_NAME}.lib") add_executable(TestCabs1 test_cabs1.c) target_link_libraries(TestCabs1 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestCabs1 "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestCabs1 OpenMP::OpenMP_CXX) endif() target_link_libraries(TestCabs1 optimized "${LIB_NAME}.lib") add_executable(TestDotv test_dotv.c) target_link_libraries(TestDotv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestDotv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestDotv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestDotv optimized "${LIB_NAME}.lib") add_executable(TestGemm test_gemm.c) target_link_libraries(TestGemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemm optimized "${LIB_NAME}.lib") add_executable(TestGemmBatch test_gemm_batch.c) target_link_libraries(TestGemmBatch debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemmBatch "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemmBatch OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemmBatch optimized "${LIB_NAME}.lib") add_executable(TestGemm3m test_gemm3m.c) target_link_libraries(TestGemm3m debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemm3m "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemm3m OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemm3m optimized "${LIB_NAME}.lib") add_executable(TestGemmt test_gemmt.c) target_link_libraries(TestGemmt debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemmt "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemmt OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemmt optimized "${LIB_NAME}.lib") add_executable(TestGemv test_gemv.c) target_link_libraries(TestGemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGemv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGemv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGemv optimized "${LIB_NAME}.lib") add_executable(TestGer test_ger.c) target_link_libraries(TestGer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestGer "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestGer OpenMP::OpenMP_CXX) endif() target_link_libraries(TestGer optimized "${LIB_NAME}.lib") add_executable(TestHemm test_hemm.c) target_link_libraries(TestHemm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHemm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHemm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHemm optimized "${LIB_NAME}.lib") add_executable(TestHemv test_hemv.c) target_link_libraries(TestHemv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHemv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHemv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHemv optimized "${LIB_NAME}.lib") add_executable(TestHer test_her.c) target_link_libraries(TestHer debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHer "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHer OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer optimized "${LIB_NAME}.lib") add_executable(TestHer2 test_her2.c) target_link_libraries(TestHer2 debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHer2 "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHer2 OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer2 optimized "${LIB_NAME}.lib") add_executable(TestHer2k test_her2k.c) target_link_libraries(TestHer2k debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHer2k "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHer2k OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHer2k optimized "${LIB_NAME}.lib") add_executable(TestHerk test_herk.c) target_link_libraries(TestHerk debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestHerk "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestHerk OpenMP::OpenMP_CXX) endif() target_link_libraries(TestHerk optimized "${LIB_NAME}.lib") add_executable(TestScalv test_scalv.c) target_link_libraries(TestScalv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestScalv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestScalv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestScalv optimized "${LIB_NAME}.lib") add_executable(TestSwapv test_swapv.c) target_link_libraries(TestSwapv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestSwapv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestSwapv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestSwapv optimized "${LIB_NAME}.lib") add_executable(TestTrmm test_trmm.c) target_link_libraries(TestTrmm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrmm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrmm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrmm optimized "${LIB_NAME}.lib") add_executable(TestTrmv test_trmv.c) target_link_libraries(TestTrmv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrmv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrmv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrmv optimized "${LIB_NAME}.lib") add_executable(TestTrsm test_trsm.c) target_link_libraries(TestTrsm debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrsm "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrsm OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrsm optimized "${LIB_NAME}.lib") add_executable(TestTrsv test_trsv.c) target_link_libraries(TestTrsv debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(TestTrsv "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(TestTrsv OpenMP::OpenMP_CXX) endif() target_link_libraries(TestTrsv optimized "${LIB_NAME}.lib") diff --git a/testsuite/CMakeLists.txt b/testsuite/CMakeLists.txt index 613f9e3861..85866926dd 100644 --- a/testsuite/CMakeLists.txt +++ b/testsuite/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src) @@ -7,8 +7,8 @@ add_executable(test_libblis "") add_subdirectory(src) target_link_libraries(test_libblis debug "${LIB_NAME}.lib") -if(ENABLE_OPENMP AND BUILD_SHARED_LIBS) - target_link_libraries(test_libblis "${OMP_LIB}") +if(ENABLE_OPENMP) + target_link_libraries(test_libblis OpenMP::OpenMP_CXX) endif() target_link_libraries(test_libblis optimized "${LIB_NAME}.lib") From 018a10826e066ca173eddcaa3fea648422b4e986 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Tue, 22 Feb 2022 12:08:53 +0530 Subject: [PATCH 27/63] Optimized ZAXPY2V using AVX2 Intrinsics Details: - Intrinsic implementation of ZAXPY2V fused kernel for AVX2 - Updated definitions in zen contexts AMD-Internal: [CPUPL-2023] Change-Id: I8889ae08c826d26e66ae607c416c4282136937fa --- config/zen/bli_cntx_init_zen.c | 3 +- config/zen2/bli_cntx_init_zen2.c | 3 +- config/zen3/bli_cntx_init_zen3.c | 3 +- kernels/zen/1f/bli_axpy2v_zen_int.c | 533 ++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 1 + 5 files changed, 540 insertions(+), 3 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index eed39b3149..1badc24f96 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,7 +80,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 9, + 10, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, @@ -93,6 +93,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, //axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index f6b8eef1e4..997ccdba2e 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 9, + 10, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -105,6 +105,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index a043d5ad22..61fefdbc31 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,7 +92,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 9, + 10, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, @@ -105,6 +105,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) BLIS_DOTXF_KER, BLIS_SCOMPLEX, bli_cdotxf_zen_int_6, // axpy2v BLIS_AXPY2V_KER, BLIS_DOUBLE, bli_daxpy2v_zen_int, + BLIS_AXPY2V_KER, BLIS_DCOMPLEX, bli_zaxpy2v_zen_int, cntx ); diff --git a/kernels/zen/1f/bli_axpy2v_zen_int.c b/kernels/zen/1f/bli_axpy2v_zen_int.c index 4ddca52162..cba0141376 100644 --- a/kernels/zen/1f/bli_axpy2v_zen_int.c +++ b/kernels/zen/1f/bli_axpy2v_zen_int.c @@ -186,3 +186,536 @@ void bli_daxpy2v_zen_int ); } } + +/** + * zaxpy2v kernel performs axpy2v operation. + * z := z + alphax * conjx(x) + alphay * conjy(y) + * where, + * x, y & z are double complex vectors of length n. + * alpha & beta are complex scalers. + */ +void bli_zaxpy2v_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alphax, + dcomplex* restrict alphay, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_4) + + // If the vectors are empty or if both alpha are zero, return early + if ( ( bli_zero_dim1( n ) ) || + ( PASTEMAC(z,eq0)( *alphax ) && PASTEMAC(z,eq0)( *alphay ) ) ) { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) + return; + } + + const dim_t n_elem_per_reg = 4; // Number of elements per register + + dim_t i = 0; // Iterator + + double* restrict x0; + double* restrict y0; + double* restrict z0; + double* restrict alphax0; + double* restrict alphay0; + + // Initialize local pointers. + x0 = (double*) x; + y0 = (double*) y; + z0 = (double*) z; + alphax0 = (double*) alphax; + alphay0 = (double*) alphay; + + if ( incx == 1 && incy == 1 && incz == 1 ) + { + //---------- Scalar algorithm BLIS_NO_CONJUGATE ------------- + // + // z = z + alphax * x + alphay * y + // z = ( zR + izI ) + + // ( axR + iaxI ) * ( xR + ixI ) + + // ( ayR + iayI ) * ( yR + iyI ) + // z = ( zR + izI ) + + // ( axR.xR + iaxR.xI + iaxI.xR - axI.xI ) + + // ( xyR.yR + iayR.yI + iayI.yR - ayI.yI ) + // z = ( zR + izI ) + + // ( ( axR.xR - axI.xI ) + i( axR.xI + axI.xR ) ) + + // ( ( ayR.yR - ayI.yI ) + i( ayR.yI + ayI.yR ) ) + // z = ( zR + axR.xR - axI.xI + ayR.yR - ayI.yI ) + + // i( zI + axR.xI + axI.xR + ayR.yI + ayI.yR ) + // + // SIMD Algorithm BLIS_NO_CONJUGATE + // xv = xR0 xI0 xR1 xI1 + // xv' = xI0 xR0 xI1 xR1 + // yv = yR0 yI0 yR1 yI1 + // yv' = yI0 yR0 yI1 yR1 + // zv = zR0 zI0 zR1 zI1 + // zv' = zI0 zR0 zI1 zR1 + // axrv = axR axR axR axR + // axiv = -axI axI -axI axI + // ayrv = ayR ayR ayR ayR + // ayiv = -ayI ayI -ayI ayI + // + // step 1: FMA zv = zv + axrv * xv + // step 2: shuffle xv -> xv' + // step 3: FMA zv = zv + axiv * xv' + // step 4: FMA zv = zv + ayrv * yv + // step 5: shuffle yv -> xyv' + // step 6: FMA zv = zv + ayiv * yv' + + //---------- Scalar algorithm BLIS_CONJUGATE ------------- + // + // z = z + alphax * x + alphay * y + // z = ( zR + izI ) + + // ( axR + iaxI ) * ( xR - ixI ) + + // ( ayR + iayI ) * ( yR - iyI ) + // z = ( zR + izI ) + + // ( axR.xR - iaxR.xI + iaxI.xR + axI.xI ) + + // ( xyR.yR - iayR.yI + iayI.yR + ayI.yI ) + // z = ( zR + izI ) + + // ( ( axR.xR + axI.xI ) + i( -axR.xI + axI.xR ) ) + + // ( ( ayR.yR + ayI.yI ) + i( -ayR.yI + ayI.yR ) ) + // z = ( zR + axR.xR + axI.xI + ayR.yR + ayI.yI ) + + // i( zI - axR.xI + axI.xR - ayR.yI + ayI.yR ) + // + // SIMD Algorithm BLIS_CONJUGATE + // xv = xR0 xI0 xR1 xI1 + // xv' = xI0 xR0 xI1 xR1 + // yv = yR0 yI0 yR1 yI1 + // yv' = yI0 yR0 yI1 yR1 + // zv = zR0 zI0 zR1 zI1 + // zv' = zI0 zR0 zI1 zR1 + // axrv = axR -axR axR -axR + // axiv = axI axI axI axI + // ayrv = ayR -ayR ayR -ayR + // ayiv = ayI ayI ayI ayI + // + // step 1: FMA zv = zv + axrv * xv + // step 2: shuffle xv -> xv' + // step 3: FMA zv = zv + axiv * xv' + // step 4: FMA zv = zv + ayrv * yv + // step 5: shuffle yv -> xyv' + // step 6: FMA zv = zv + ayiv * yv' + + __m256d alphaxRv; + __m256d alphaxIv; + __m256d alphayRv; + __m256d alphayIv; + __m256d xv[4]; + __m256d yv[4]; + __m256d zv[4]; + + double alphaxR, alphaxI; + double alphayR, alphayI; + + alphaxR = alphax->real; + alphaxI = alphax->imag; + alphayR = alphay->real; + alphayI = alphay->imag; + + // Broadcast alphax & alphay to respective vector registers + if ( !bli_is_conj( conjx ) ) // If not x conjugate + { + // alphaxRv = axR axR axR axR + // alphaxIv = -axI axI -axI axI + alphaxRv = _mm256_broadcast_sd( &alphaxR ); + alphaxIv = _mm256_set_pd( alphaxI, -alphaxI, alphaxI, -alphaxI ); + } + else + { + // alphaxRv = axR -axR axR -axR + // alphaxIv = axI axI axI axI + alphaxRv = _mm256_set_pd( -alphaxR, alphaxR, -alphaxR, alphaxR ); + alphaxIv = _mm256_broadcast_sd( &alphaxI ); + } + + if ( !bli_is_conj( conjy ) ) // If not y conjugate + { + // alphayRv = ayR ayR ayR ayR + // alphayIv = -ayI ayI -ayI ayI + alphayRv = _mm256_broadcast_sd( &alphayR ); + alphayIv = _mm256_set_pd( alphayI, -alphayI, alphayI, -alphayI ); + } + else + { + // alphayRv = ayR -ayR ayR -ayR + // alphayIv = ayI ayI ayI ayI + alphayRv = _mm256_set_pd( -alphayR, alphayR, -alphayR, alphayR ); + alphayIv = _mm256_broadcast_sd( &alphayI ); + } + + // Processing 8 elements per loop, 16 FMAs + for ( ; ( i + 7 ) < n; i += 8 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + xv[2] = _mm256_loadu_pd( x0 + 2*n_elem_per_reg ); + xv[3] = _mm256_loadu_pd( x0 + 3*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + yv[2] = _mm256_loadu_pd( y0 + 2*n_elem_per_reg ); + yv[3] = _mm256_loadu_pd( y0 + 3*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z0 + 1*n_elem_per_reg ); + zv[2] = _mm256_loadu_pd( z0 + 2*n_elem_per_reg ); + zv[3] = _mm256_loadu_pd( z0 + 3*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxRv, zv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphaxRv, zv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphaxRv, zv[3] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + xv[1] = _mm256_permute_pd( xv[1], 5 ); + xv[2] = _mm256_permute_pd( xv[2], 5 ); + xv[3] = _mm256_permute_pd( xv[3], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxIv, zv[1] ); + zv[2] = _mm256_fmadd_pd( xv[2], alphaxIv, zv[2] ); + zv[3] = _mm256_fmadd_pd( xv[3], alphaxIv, zv[3] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayRv, zv[1] ); + zv[2] = _mm256_fmadd_pd( yv[2], alphayRv, zv[2] ); + zv[3] = _mm256_fmadd_pd( yv[3], alphayRv, zv[3] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + yv[1] = _mm256_permute_pd( yv[1], 5 ); + yv[2] = _mm256_permute_pd( yv[2], 5 ); + yv[3] = _mm256_permute_pd( yv[3], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayIv, zv[1] ); + zv[2] = _mm256_fmadd_pd( yv[2], alphayIv, zv[2] ); + zv[3] = _mm256_fmadd_pd( yv[3], alphayIv, zv[3] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (z0 + 1*n_elem_per_reg), zv[1] ); + _mm256_storeu_pd( (z0 + 2*n_elem_per_reg), zv[2] ); + _mm256_storeu_pd( (z0 + 3*n_elem_per_reg), zv[3] ); + + x0 += 4*n_elem_per_reg; + y0 += 4*n_elem_per_reg; + z0 += 4*n_elem_per_reg; + } + + // Processing 4 elements per loop, 8 FMAs + for ( ; ( i + 3 ) < n; i += 4 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + xv[1] = _mm256_loadu_pd( x0 + 1*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + yv[1] = _mm256_loadu_pd( y0 + 1*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + zv[1] = _mm256_loadu_pd( z0 + 1*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxRv, zv[1] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + xv[1] = _mm256_permute_pd( xv[1], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( xv[1], alphaxIv, zv[1] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayRv, zv[1] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + yv[1] = _mm256_permute_pd( yv[1], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + zv[1] = _mm256_fmadd_pd( yv[1], alphayIv, zv[1] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + _mm256_storeu_pd( (z0 + 1*n_elem_per_reg), zv[1] ); + + x0 += 2*n_elem_per_reg; + y0 += 2*n_elem_per_reg; + z0 += 2*n_elem_per_reg; + } + + // Processing 2 elements per loop, 4FMAs + for ( ; ( i + 1 ) < n; i += 2 ) + { + // Loading x vector + // xv = xR0 xI0 xR1 xI1 + xv[0] = _mm256_loadu_pd( x0 + 0*n_elem_per_reg ); + + // Loading y vector + // yv = yR0 yI0 yR1 yI1 + yv[0] = _mm256_loadu_pd( y0 + 0*n_elem_per_reg ); + + // Loading z vector + // zv = zR0 zI0 zR1 zI1 + zv[0] = _mm256_loadu_pd( z0 + 0*n_elem_per_reg ); + + // zv = zv + alphaxRv * xv + // zv = zR0 + axR.xR0, zI0 + axR.xI0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxRv, zv[0] ); + + // Shuffling xv + // xv = xI0 xR0 xI1 xR1 + xv[0] = _mm256_permute_pd( xv[0], 5 ); + + // zv = zv + alphaxIv * xv + // zv = zR0 + axR.xR0 - axI.xI0, zI0 + axR.xI0 + axI.xR0, ... + zv[0] = _mm256_fmadd_pd( xv[0], alphaxIv, zv[0] ); + + // zv = zv + alphayRv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayRv, zv[0] ); + + // Shuffling yv + // yv = yI0 yR0 yI1 yR1 + yv[0] = _mm256_permute_pd( yv[0], 5 ); + + // zv = zv + alphayIv * yv + // zv = zR0 + axR.xR0 - axI.xI0 + ayR.yR0 - ayI.yI0, + // zI0 + axR.xI0 + axI.xR0 + ayR.yI0 + ayI.yR0, ... + zv[0] = _mm256_fmadd_pd( yv[0], alphayIv, zv[0] ); + + // Storing results from zv + _mm256_storeu_pd( (z0 + 0*n_elem_per_reg), zv[0] ); + + x0 += 1*n_elem_per_reg; + y0 += 1*n_elem_per_reg; + z0 += 1*n_elem_per_reg; + } + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + if ( !bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else if ( !bli_is_conj( conjx ) && bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else if ( bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + else + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2; + y0 += 2; + z0 += 2; + } + } + } + else + { + // Using scalar code for non-unit increments + if ( !bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else if ( !bli_is_conj( conjx ) && bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR - axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) - + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI + axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*alphax0) * (*(x0 + 1)) + + (*(alphax0 + 1)) * (*x0) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else if ( bli_is_conj( conjx ) && !bli_is_conj( conjy ) ) + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR - ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) - + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI + ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*alphay0) * (*(y0 + 1)) + + (*(alphay0 + 1)) * (*y0); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + else + { + for ( ; i < n; i++ ) + { + // zR += ( axR.xR + axI.xI + ayR.yR + ayI.yI ) + *z0 += (*alphax0) * (*x0) + + (*(alphax0 + 1)) * (*(x0 + 1)) + + (*alphay0) * (*y0) + + (*(alphay0 + 1)) * (*(y0 + 1)); + + // zI += ( axR.xI - axI.xR + ayR.yI - ayI.yR ) + *(z0 + 1) += (*(alphax0 + 1)) * (*x0) - + (*alphax0) * (*(x0 + 1)) + + (*(alphay0 + 1)) * (*y0) - + (*alphay0) * (*(y0 + 1)); + + x0 += 2 * incx; + y0 += 2 * incy; + z0 += 2 * incz; + } + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_4) +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 537e67038a..10a656835f 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -115,6 +115,7 @@ AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_5 ) AXPYF_KER_PROT( dcomplex, z, axpyf_zen_int_4 ) // axpy2v (intrinsics) AXPY2V_KER_PROT(double, d, axpy2v_zen_int ) +AXPY2V_KER_PROT(dcomplex, z, axpy2v_zen_int ) // dotxf (intrinsics) DOTXF_KER_PROT( float, s, dotxf_zen_int_8 ) From fecf7d5af44233ecfcc3fa0a096f4d4419d81753 Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 23 Feb 2022 13:11:54 +0530 Subject: [PATCH 28/63] AOCL_Windows: Updated windows build system. Removed the "target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX)" statement for the static ST library builb. This statement is not needed for static ST library build, mistakenly added. Change-Id: I577a28c75644043fd077d938bf7f51cdea8ee13d --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 8ba483e36b..1320460a29 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -548,7 +548,6 @@ if(NOT BUILD_SHARED_LIBS) set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}" STATIC_LIBRARY_OPTIONS "${OpenMP_libomp_LIBRARY}") else() set_target_properties("${PROJECT_NAME}" PROPERTIES LINKER_LANGUAGE C OUTPUT_NAME "${LIB_NAME}") - target_link_libraries("${PROJECT_NAME}" PRIVATE OpenMP::OpenMP_CXX) endif() endif() From 7158945d64864a440ec6cd0d9488c7b80917d28f Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 2 Mar 2022 04:08:26 -0600 Subject: [PATCH 29/63] dher2 API in blis make check fails on non avx2 platform - dher2 did not have avx check for platform. It was calling avx kernel regardless of platform support. Which resulted in core dump. - Added avx based platform check in both variant of dher2 for fixing the issue. AMD-Internal: [CPUPL-2043] Change-Id: I1fd1dcc9336980bfb7ffa9376f491f107c889c0b --- frame/2/her2/bli_her2_unf_var1_amd.c | 64 ++++++++++++++++++---------- frame/2/her2/bli_her2_unf_var4_amd.c | 39 +++++++++++------ 2 files changed, 67 insertions(+), 36 deletions(-) diff --git a/frame/2/her2/bli_her2_unf_var1_amd.c b/frame/2/her2/bli_her2_unf_var1_amd.c index 43a74f49cd..31667cc3e4 100644 --- a/frame/2/her2/bli_her2_unf_var1_amd.c +++ b/frame/2/her2/bli_her2_unf_var1_amd.c @@ -249,9 +249,13 @@ void bli_dher2_unf_var1 PASTECH(d,axpy2v_ker_ft) kfp_2v; /* Query the context for the kernel function pointer. */ - kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - if( (incx == 1) && (incy == 1) && (rs_ct == 1)) + if ( (bli_cpuid_is_avx_supported() == TRUE) + && (incx == 1) + && (incy == 1) + && (rs_ct == 1)) { for ( i = 0; i < m; ) { @@ -265,29 +269,43 @@ void bli_dher2_unf_var1 if((n_behind >= 3)) { - bli_dher2_trans_zen_int_4(c10t, x0, y0, &alpha0, n_behind + 1, cs_ct); + bli_dher2_trans_zen_int_4(c10t, x0, y0, + &alpha0, + n_behind + 1, + cs_ct); i+=4; } else { - /* Apply conjx and/or conjy to chi1 and/or psi1. */ - PASTEMAC(d,copycjs)( conjx, *chi1, conjx0_chi1 ); - PASTEMAC(d,copycjs)( conjy, *psi1, conjy1_psi1 ); - PASTEMAC(d,copycjs)( conj0, *psi1, conjy0_psi1 ); - - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, conjx0_chi1, alpha0_chi1 ); - PASTEMAC(d,scal2s)( alpha1, conjy1_psi1, alpha1_psi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have already been conjugated, if needed, + /* Apply conjx and/or conjy to chi1 + * and/or psi1. */ + PASTEMAC(d,copycjs)( conjx, *chi1, + conjx0_chi1 ); + PASTEMAC(d,copycjs)( conjy, *psi1, + conjy1_psi1 ); + PASTEMAC(d,copycjs)( conj0, *psi1, + conjy0_psi1 ); + + /* Compute scalars for vector + * subproblems. */ + PASTEMAC(d,scal2s)( alpha0, + conjx0_chi1, + alpha0_chi1 ); + PASTEMAC(d,scal2s)( alpha1, + conjy1_psi1, + alpha1_psi1 ); + + /* Compute alpha * chi1 * conj(psi1) + * after both chi1 and psi1 have + * already been conjugated, if needed * by conjx and conjy. */ - PASTEMAC(d,scal2s)( alpha0_chi1, conjy0_psi1, - alpha0_chi1_psi1 ); + PASTEMAC(d,scal2s)( alpha0_chi1, + conjy0_psi1, + alpha0_chi1_psi1 ); - /* c10t = c10t + alpha * chi1 * y0'; */ - /* c10t = c10t + conj(alpha) * psi1 * x0'; */ + /* c10t = c10t + alpha * chi1 * y0';*/ + /* c10t = c10t + conj(alpha) * psi1 * x0';*/ kfp_2v ( conj0, @@ -301,10 +319,12 @@ void bli_dher2_unf_var1 cntx ); - /* gamma11 = gamma11 + alpha * chi1 * conj(psi1) - + conj(alpha) * psi1 * conj(chi1); */ - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + /* gamma11 = gamma11 + alpha * chi1 *conj(psi1) + * + conj(alpha) * psi1 * conj(chi1);*/ + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); i+=1; } diff --git a/frame/2/her2/bli_her2_unf_var4_amd.c b/frame/2/her2/bli_her2_unf_var4_amd.c index 4d77397cd2..6e999be7d1 100644 --- a/frame/2/her2/bli_her2_unf_var4_amd.c +++ b/frame/2/her2/bli_her2_unf_var4_amd.c @@ -246,9 +246,13 @@ void bli_dher2_unf_var4 PASTECH(d,axpy2v_ker_ft) kfp_2v; /* Query the context for the kernel function pointer. */ + if ( cntx == NULL ) cntx = bli_gks_query_cntx(); kfp_2v = bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPY2V_KER, cntx ); - if((incx == 1) && (incy == 1) && (rs_ct == 1)) + if ( (bli_cpuid_is_avx_supported() == TRUE) + && (incx == 1) + && (incy == 1) + && (rs_ct == 1)) { for ( i = 0; i < m; ) { @@ -262,23 +266,28 @@ void bli_dher2_unf_var4 if((n_ahead >= 3)) { - bli_dher2_zen_int_4(gamma11, chi1, psi1, &alpha0, n_ahead + 1, cs_ct); + bli_dher2_zen_int_4(gamma11, chi1, + psi1, &alpha0, + n_ahead + 1, cs_ct); i+= 4; } else { - /* Compute scalars for vector subproblems. */ - PASTEMAC(d,scal2s)( alpha0, *psi1, alpha0_psi1 ); - PASTEMAC(d,scal2s)( alpha0, *chi1, alpha1_chi1 ); - - /* Compute alpha * chi1 * conj(psi1) after both chi1 - * and psi1 have - already been conjugated, if needed, by conjx and - conjy. */ + /* Compute scalars for vector + * subproblems. */ + PASTEMAC(d,scal2s)( alpha0, *psi1, + alpha0_psi1 ); + PASTEMAC(d,scal2s)( alpha0, *chi1, + alpha1_chi1 ); + + /* Compute alpha * chi1 * conj(psi1) + * after both chi1 and psi1 have + * already been conjugated, if needed, + * by conjx and conjy. */ PASTEMAC(d,scal2s)( alpha0_psi1, *chi1, - alpha0_chi1_psi1 ); + alpha0_chi1_psi1 ); - /* c21 = c21 + alpha * x2 * conj(psi1); */ + /* c21 = c21 + alpha * x2 * conj(psi1)*/ /* c21 = c21 + conj(alpha) * y2 * conj(chi1); */ kfp_2v @@ -295,8 +304,10 @@ void bli_dher2_unf_var4 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); - PASTEMAC(d,adds)( alpha0_chi1_psi1, *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); + PASTEMAC(d,adds)( alpha0_chi1_psi1, + *gamma11 ); i+=1; } } From a9dbab13eec7a4916682f022275a0e605267fed2 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 7 Mar 2022 14:38:08 +0530 Subject: [PATCH 30/63] Updated Windows build system to pick AMD specific sources. The framework cleanup was done for linux as part of f63f78d7 Removed Arch specific code from BLIS framework. This commit adds changes needed for windows build. AMD-Internal: [CPUPL-2052] Change-Id: Ibd503a0adeea66850de156fb95657b124e1c4b9d --- .gitignore | 10 +++++ CMakeLists.txt | 4 -- frame/2/gemv/CMakeLists.txt | 20 +++++++-- frame/2/hemv/CMakeLists.txt | 21 +++++++-- frame/2/her2/CMakeLists.txt | 21 +++++++-- frame/2/trsv/CMakeLists.txt | 21 +++++++-- frame/3/CMakeLists.txt | 18 +++++++- frame/3/gemm/CMakeLists.txt | 19 +++++++- frame/compat/CMakeLists.txt | 44 ++++++++++++++----- frame/compat/bla_gemm_amd.c | 4 +- kernels/zen/1f/CMakeLists.txt | 3 +- ...xpyf_int_8.c => bli_dotxaxpyf_zen_int_8.c} | 0 kernels/zen/2/CMakeLists.txt | 1 + 13 files changed, 153 insertions(+), 33 deletions(-) rename kernels/zen/1f/{bli_dotxaxpyf_int_8.c => bli_dotxaxpyf_zen_int_8.c} (100%) diff --git a/.gitignore b/.gitignore index b3b811654a..539f959076 100644 --- a/.gitignore +++ b/.gitignore @@ -52,3 +52,13 @@ out.* GPATH GRTAGS GTAGS + +# Windows Build +build/* +bin/* +*.dll +*.lib +*.pdb +*.exe + +.vscode diff --git a/CMakeLists.txt b/CMakeLists.txt index 1320460a29..c7c3f31395 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,20 +34,17 @@ endif () if(${AOCL_BLIS_FAMILY} STREQUAL "zen") add_definitions(-DBLIS_FAMILY_ZEN) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen2") add_definitions(-DBLIS_FAMILY_ZEN2) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN2) add_definitions(-DBLIS_KERNELS_ZEN) add_definitions(-DBLIS_KERNELS_HASWELL) elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") add_definitions(-DBLIS_FAMILY_ZEN3) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN3) add_definitions(-DBLIS_KERNELS_ZEN2) @@ -56,7 +53,6 @@ elseif (${AOCL_BLIS_FAMILY} STREQUAL "zen3") elseif (${AOCL_BLIS_FAMILY} STREQUAL "amdzen") set(AOCL_BLIS_ZEN FALSE) add_definitions(-DBLIS_FAMILY_AMDZEN) - add_definitions(-DBLIS_CONFIG_EPYC) add_definitions(-DBLIS_CONFIG_ZEN3) add_definitions(-DBLIS_CONFIG_ZEN2) add_definitions(-DBLIS_CONFIG_ZEN) diff --git a/frame/2/gemv/CMakeLists.txt b/frame/2/gemv/CMakeLists.txt index 86be8ddc08..2f75a00f63 100644 --- a/frame/2/gemv/CMakeLists.txt +++ b/frame/2/gemv/CMakeLists.txt @@ -1,11 +1,25 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unb_var2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var2.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_unf_var1.c + ) +endif() diff --git a/frame/2/hemv/CMakeLists.txt b/frame/2/hemv/CMakeLists.txt index 677c253271..34820c3762 100644 --- a/frame/2/hemv/CMakeLists.txt +++ b/frame/2/hemv/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,10 +6,25 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unb_var4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1a.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3a.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_hemv_unf_var3.c + ) +endif() \ No newline at end of file diff --git a/frame/2/her2/CMakeLists.txt b/frame/2/her2/CMakeLists.txt index 1b4c264443..83629df8f5 100644 --- a/frame/2/her2/CMakeLists.txt +++ b/frame/2/her2/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,8 +6,23 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unb_var4.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_unf_var4.c + ) +endif() \ No newline at end of file diff --git a/frame/2/trsv/CMakeLists.txt b/frame/2/trsv/CMakeLists.txt index 1d16769d32..b07389340e 100644 --- a/frame/2/trsv/CMakeLists.txt +++ b/frame/2/trsv/CMakeLists.txt @@ -1,11 +1,26 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unb_var2.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_var_oapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var1.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsv_unf_var2.c + ) +endif() diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index 4b7711ed4e..b3aaf2c8c8 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -12,7 +12,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_packm.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_prune.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_a.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_b.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_packm_var.c @@ -27,6 +26,21 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_oapi.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_tapi.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR + ${TARGET_ARCH} STREQUAL zen2 OR + ${TARGET_ARCH} STREQUAL zen3 OR + ${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ) +endif() set(SUBDIRECTORIES "gemm" "hemm" "her2k" "herk" "symm" "syr2k" "syrk" "trmm" "trmm3" "trsm" "gemmt") diff --git a/frame/3/gemm/CMakeLists.txt b/frame/3/gemm/CMakeLists.txt index 8eb115d1f0..825dd745ca 100644 --- a/frame/3/gemm/CMakeLists.txt +++ b/frame/3/gemm/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -6,7 +6,6 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_blk_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_blk_var3.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_cntl.c - ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_int.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ker_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_ker_var2.c @@ -16,4 +15,20 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_packab.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemm_front.c + ) +endif() + add_subdirectory(ind) diff --git a/frame/compat/CMakeLists.txt b/frame/compat/CMakeLists.txt index 7c20f5100c..48b66acbcb 100644 --- a/frame/compat/CMakeLists.txt +++ b/frame/compat/CMakeLists.txt @@ -1,17 +1,12 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE -${CMAKE_CURRENT_SOURCE_DIR}/bla_amax.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amin.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_asum.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_copy.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm3m.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemmt.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_ger.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_hemm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_hemv.c @@ -20,8 +15,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_her2.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_her2k.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_herk.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_nrm2.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_scal.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_swap.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_symm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_symv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_syr.c @@ -30,7 +23,6 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_syr2k.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_syrk.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trmv.c -${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsv.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm_batch.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpby.c @@ -40,6 +32,38 @@ ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatcopy2.c ${CMAKE_CURRENT_SOURCE_DIR}/bla_omatadd.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amax_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_copy_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_scal_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_swap_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bla_amax.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_axpy.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_copy.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_dot.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemm.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_gemv.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_scal.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_swap.c + ${CMAKE_CURRENT_SOURCE_DIR}/bla_trsm.c + ) +endif() + #Add all subdirectories # add_subdirectory(attic) # add_subdirectory(blis) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 7ef58bfb35..197cc3e235 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -798,7 +798,7 @@ INSERT_GENTFUNC_BLAS_SC( gemm, gemm ) // Observed a regression in dgemm with this function addition. // Disabling temporarily. -#if 0 +#if 1 void dzgemm_ ( const f77_char* transa, @@ -875,7 +875,7 @@ void dzgemm_ bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - bli_obj_init_finish( dt_a, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); diff --git a/kernels/zen/1f/CMakeLists.txt b/kernels/zen/1f/CMakeLists.txt index 4b9caa40b6..3a77f69ef1 100644 --- a/kernels/zen/1f/CMakeLists.txt +++ b/kernels/zen/1f/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -8,4 +8,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_4.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpyf_zen_int_6.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_axpy2v_zen_int.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_dotxaxpyf_zen_int_8.c ) diff --git a/kernels/zen/1f/bli_dotxaxpyf_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c similarity index 100% rename from kernels/zen/1f/bli_dotxaxpyf_int_8.c rename to kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c diff --git a/kernels/zen/2/CMakeLists.txt b/kernels/zen/2/CMakeLists.txt index f20d114781..d4ad0143ed 100644 --- a/kernels/zen/2/CMakeLists.txt +++ b/kernels/zen/2/CMakeLists.txt @@ -4,6 +4,7 @@ target_sources("${PROJECT_NAME}" PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_ref.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_her2_zen_int_4.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemv_zen_int_4.c ) From ee02ccd039fc80675402682bca6cd94ea2492df8 Mon Sep 17 00:00:00 2001 From: mkurumel Date: Fri, 18 Feb 2022 16:00:13 +0530 Subject: [PATCH 31/63] DGEMMT : Tuning SUP threshold to improve ST and MT performance. Details : - SUP Threshold change for native vs SUP - Improved the ST performances for sizes n<800 - Introduce PACKB in SUP to improve ST performance between 320 320) && (k > 50)) + bli_rntm_set_pack_b( 1, rntm ); + } + } @@ -317,6 +326,14 @@ err_t bli_gemmtsup_int // new ways of parallelism value for the jc loop. bli_rntm_set_ways_only( jc_new, 1, ic_new, 1, 1, rntm ); bli_l3_sup_thrinfo_update_root( rntm, thread ); + + /* Enable packing for A matrix for higher sizes. Note that pack A + * * becomes pack B inside var2m because this is transpose case*/ + if(bli_is_double(dt) && (n_threads==1)) + { + if((m > 320) && (k > 50)) + bli_rntm_set_pack_a( 1, rntm ); + } } diff --git a/kernels/zen/util/bli_thresh_funcs_zen.c b/kernels/zen/util/bli_thresh_funcs_zen.c index 1b5fc86998..2786f00e43 100644 --- a/kernels/zen/util/bli_thresh_funcs_zen.c +++ b/kernels/zen/util/bli_thresh_funcs_zen.c @@ -37,16 +37,31 @@ // -- gemmt specific function bool bli_cntx_gemmtsup_thresh_is_met_zen( obj_t* a, obj_t* b, obj_t* c, cntx_t* cntx ) { - num_t dt = bli_obj_dt( c ); + num_t dt = bli_obj_dt( c ); + dim_t n = bli_obj_length( c ); + dim_t k = bli_obj_width_after_trans( a ); + rntm_t rntm; - dim_t n = bli_obj_length( c ); - dim_t k = bli_obj_width_after_trans( a ); + bli_rntm_init_from_global( &rntm ); + + // Query the number of threads from rntm object. + const dim_t n_threads = bli_rntm_num_threads( &rntm ); if( bli_is_double( dt )) { - if ( n < 300 ) return TRUE; - if ( (k / n ) > 50 ) return TRUE; - + if( n_threads == 16) + { + /*Push sizes for n<1200 into SUP path*/ + if ( n < 1200 ) return TRUE; + /*For 12005 , With packing , Native path performance is better */ + if ( n < 1600 && (n / k) < 5) return TRUE; + } + else + { + if ( n < 800 ) return TRUE; + if ( (k / n ) > 50 ) return TRUE; + } return FALSE; } else if ( bli_is_dcomplex( dt ) ) From 613fb6a90ad4d4800a260134cb006d7236b5d360 Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Tue, 15 Mar 2022 16:33:55 +0530 Subject: [PATCH 32/63] Runtime Thread Control using OpenMP API Details: - During runtime, Application can set the desired number of threads using standard OpenMP API omp_set_num_threads(nt). - BLIS Library uses standard OpenMP API omp_get_max_threads() internally, to fetch the latest value set by the application. - This value will be used to decide the number of threads in the subsequent BLAS calls. - At the time of BLIS Initialization, BLIS_NUM_THREADS environment variable will be given precedence, over the OpenMP standard API omp_set_num_threads(nt) and OMP_NUM_THREADS environment variable. - Order of precedence followed during BLIS Initialization is as follows 1. Valid value of BLIS_NUM_THREADS 2. omp_set_num_threads(nt) 3. valid value of OMP_NUM_THREADS 4. Number of cores - After BLIS initialization, if the Application issues omp_set_num_threads(nt) during runtime, number of threads set during BLIS Initialization, is overridden by the latest value set by the Application. - Existing precedence of BLIS_*_NT environment variables and the decision of optimal number of threads over the number of threads derived from the above process remains as it is. AMD-Internal: [CPUPL-2076] Change-Id: I935ba0246b1c256d0fee7d386eac0f5940fabff8 --- frame/base/bli_rntm.c | 14 +++++++++++++ frame/thread/bli_thread.c | 44 +++++++++++++++++++++++++++++++-------- 2 files changed, 49 insertions(+), 9 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index dc0acf6bf9..7176dacc4e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -49,9 +49,23 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // We must ensure that global_rntm has been initialized. bli_init_once(); + // Fetch the number of threads based on the order of precedence, + // or the latest value of number of threads, + // if set by the Application using omp_set_num_threads(nt) API. +#ifdef BLIS_ENABLE_OPENMP + dim_t n_threads = omp_get_max_threads(); +#endif + // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); + // Update the latest value of number of threads into global rntm structure, + // before copying into local rntm structure. This updated value will be + // used in the subsequent parallel regions. +#ifdef BLIS_ENABLE_OPENMP + global_rntm.num_threads = n_threads; +#endif + *rntm = global_rntm; // Release the mutex protecting global_rntm. diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index 159a9e802e..f570bcc2d8 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1633,20 +1633,46 @@ void bli_thread_init_rntm_from_env // Try to read BLIS_NUM_THREADS first. nt = bli_env_get_var( "BLIS_NUM_THREADS", -1 ); - // If BLIS_NUM_THREADS was not set, try to read OMP_NUM_THREADS. - if ( nt == -1 ) - nt = bli_env_get_var( "OMP_NUM_THREADS", -1 ); #ifdef BLIS_ENABLE_OPENMP - // If both environment variables are not set - - // number of threads can also be set by the application by calling omp_set_num_threads(nt) - // The next parallel region when encountered will run with number of threads set by the above API. - // We can know about the number of threads by using the API "omp_get_max_threads()" - if (nt == -1) nt = omp_get_max_threads(); - // If application is multithreaded and number of threads is set using omp_set_num_threads(nt) + + // Scenarios: + // 1. If BLIS_NUM_THREADS is set with valid value, set the nt using omp_set_num_threads(nt) + // so that this value can be fetched inside BLIS API as well. + // 2. If BLIS_NUM_THREADS is not set, then if Application is multithreaded and issued + // omp_set_num_threads(nt) with desired number of threads, + // omp_get_max_threads() API will fetch the number of threads set earlier. + // 3. If BLIS_NUM_THREADS is not set, omp_set_num_threads(nt) is not called by the application, + // but only OMP_NUM_THREADS is set, + // omp_get_max_threads() API will fetch the value of OMP_NUM_THREADS. + // 4. If both environment variables are not set, or if they are set with invalid values, and + // omp_set_num_threads(nt) is not issued by application, + // omp_get_max_threads() API will return the number of the cores in the current context. + // // BLIS will rntm->num_threads will also get initialized with the same value. // However if omp_set_nested is false - BLIS APIs called from parallel threads will run in sequential. // But if nested parallelism is enabled - Then each application will launch MT BLIS. + // + // Order of precedence used for number of threads: + // 1. valid value set for BLIS_NUM_THREADS environment variable + // 2. omp_set_num_threads(nt) issued by the application + // 3. valid value set for OMP_NUM_THREADS environment variable + // 4. Number of cores + // + // Note: If nt is not a valid value for omp_set_num_threads(nt) API, number of threads would be set to 1. + // omp_get_max_threads() API will return 1. + // + // OMP_NUM_THREADS environment variable is applicable only when OpenMP is enabled. + + if(nt > 0) + { + omp_set_num_threads(nt); + } + else + { + nt = omp_get_max_threads(); + } + #endif // Read the environment variables for the number of threads (ways // of parallelism) for each individual loop. From 1b301e8db74bc5ee30599be3a95026105b0fac0a Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 16 Mar 2022 11:51:05 +0530 Subject: [PATCH 33/63] AOCL-Windows: Added logic in the windows build system to generate cblas.h at configure time. AMD-Internal: [CPUPL-2037] Change-Id: Ie4ffd1d655079c895878f96dbb6f811547ad953d --- CMakeLists.txt | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index c7c3f31395..0483435679 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -521,6 +521,23 @@ execute_process( OUTPUT_VARIABLE CMD_OUTPUT) message( STATUS "Generating monolithic header file :" ${CMD_OUTPUT}) +# Logic to generate the cblas.h in include folder. +set(CBLAS_H "cblas.h") +# Arguements for python script +set(C_COMMENT "-c") +set(VERBOSE "-v1") +set(INPUT "${CMAKE_SOURCE_DIR}/frame/compat/cblas/src/${CBLAS_H}") +set(OUTPUT "${CMAKE_SOURCE_DIR}/include/${TARGET_ARCH}/${CBLAS_H}") +set(TEMP_DIR "${INCLUDE}") +set(DIR_H_PATH "${HEADER_PATH}") + +# Run python script to generate monolithic header at configuration time +execute_process( + COMMAND ${PYTHON_EXE} ${FLATTEN_PY} "${C_COMMENT}" "${VERBOSE}" "${INPUT}" "${OUTPUT}" "${TEMP_DIR}" "${DIR_H_PATH}" + RESULT_VARIABLE CMD_RESULT + OUTPUT_VARIABLE CMD_OUTPUT) +message( STATUS "Generating monolithic cblas header file :" ${CMD_OUTPUT}) + # setting the blis version string file (STRINGS "version" BLIS_VERSION) set(BLIS_VERSION_STRING ${BLIS_VERSION}) From 1996a674e08dbb77d7a8530aa27ed177d0854f4d Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Wed, 23 Mar 2022 10:30:14 +0530 Subject: [PATCH 34/63] Fine-tuning dynamic threading logic of DGEMM for small dimensions Description: 1. For small dimensions single threads dgemm_small performing better than dgemmsup and native paths. 2. Irrespecive of given number of threads we are redirecting into single thread dgemm_small AMD-Internal:[CPUPL-2053] Change-Id: If591152d18282c2544249f70bd2f0a8cd816b94e --- frame/compat/bla_gemm_amd.c | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 197cc3e235..2bb9126804 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -526,33 +526,31 @@ void dgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. - // if m0 is large and (n0 & k0) < 10 - SMALL GEMM - ST is better - // - #ifdef AOCL_DYNAMIC - if (nt && ((n0 > 10 ) || (k0 > 10)) ) + //For smaller sizes dgemm_small is perfoming better + if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>150)) ) #else - if (nt) + if (nt) #endif - { + { // Will call parallelized dgemm code - sup & native PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); /* Finalize BLIS. */ bli_finalize_auto(); return; - } + } // The code below will be called when number of threads = 1. From 4e1c251c719feed667d897359ce99fc25412e3d7 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Tue, 22 Mar 2022 06:59:36 -0500 Subject: [PATCH 35/63] Implement zgemm_small kernel Details: - Intrinsic implementation of zgemm_small nn kernel. - Intrinsic implementation of zgemm_small_At kernel. - Added support conjugate and hermitian transpose - Main loop operates in multiple of 4x3 tile. - Edge cases are handles separately. AMD-Internal: [CPUPL-2084] Change-Id: I512da265e4d4ceec904877544f1d15cddc147a66 --- frame/compat/bla_gemm_amd.c | 46 +- kernels/zen/3/bli_gemm_small.c | 7726 +++++++++++++++++++++++++++++++- 2 files changed, 7759 insertions(+), 13 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 2bb9126804..ff995b5f07 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -712,7 +712,12 @@ void zgemm_ //dim_t nt = bli_thread_get_num_threads(); // get number of threads bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. - if ( nt ) +#ifdef AOCL_DYNAMIC + //For smaller sizes zgemm_small is perfoming better + if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>100)) ) +#else + if (nt) +#endif { // Will call parallelized zgemm code - sup & native PASTEMAC(gemm, BLIS_OAPI_EX_SUF) @@ -733,6 +738,31 @@ void zgemm_ return; } +#ifdef BLIS_ENABLE_SMALL_MATRIX + err_t status; + + if((nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 )) + { + status = bli_gemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } +#endif // The code below will be called when number of threads = 1. #if ENABLE_INDUCED_METHOD /* 3m_sqp is optimal for certain matrix shapes. @@ -769,13 +799,13 @@ void zgemm_ // sup has been enabled for single instance cases. if(single_instance==1) { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } } // fall back on native path when zgemm is not handled in sup path. diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 4815d57d72..18745b9c3f 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -40,6 +40,7 @@ #define MR 32 #define D_MR (MR >> 1) +#define Z_MR (MR >> 3) #define NR 3 #define D_BLIS_SMALL_MATRIX_K_THRES_ROME 256 @@ -70,7 +71,26 @@ err_t bli_dgemm_small cntx_t* cntx, cntl_t* cntl ); - +static err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); +static err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); static err_t bli_sgemm_small_atbn ( obj_t* alpha, @@ -112,7 +132,7 @@ err_t bli_gemm_small #ifdef BLIS_ENABLE_MULTITHREADING AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); - return BLIS_NOT_YET_IMPLEMENTED; + return BLIS_NOT_YET_IMPLEMENTED; #else // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. @@ -152,6 +172,18 @@ err_t bli_gemm_small return bli_dgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif } + if(dt == BLIS_DCOMPLEX) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_zgemm_small_At is called directly from blas interface for + // sizes within thresholds. + // Avoinding calling of bli_zgemm_small_At from gemm_front + // and directing to native implementation. + return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); +#endif + } if (bli_obj_has_notrans( b )) { @@ -180,6 +212,19 @@ err_t bli_gemm_small #endif } + if (dt == BLIS_DCOMPLEX) + { +#ifndef BLIS_ENABLE_MULTITHREADING + // bli_zgemm_small is called directly from BLAS interface for sizes within thresholds. + // Avoiding calling bli_zgemm_small from gemm_front and directing to + // native implementation. + return BLIS_NOT_YET_IMPLEMENTED; +#else + return bli_zgemm_small(alpha, a, b, beta, c, cntx, cntl); +#endif + } + + if (dt == BLIS_FLOAT) { return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl); @@ -189,7 +234,6 @@ err_t bli_gemm_small return BLIS_NOT_YET_IMPLEMENTED; }; - static err_t bli_sgemm_small ( obj_t* alpha, @@ -2865,7 +2909,6 @@ static err_t bli_sgemm_small if (m_remainder >= 4) { - //printf("HERE\n"); m_remainder -= 4; for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) @@ -5377,7 +5420,6 @@ err_t bli_dgemm_small_At if (m_remainder >= 4) { - //printf("HERE\n"); m_remainder -= 4; tA = A + row_idx * lda; @@ -5705,5 +5747,7679 @@ err_t bli_dgemm_small_At return BLIS_NONCONFORMAL_DIMENSIONS; } }; + + +#define BLIS_SET_YMM_REG_ZEROS \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + ymm16 = _mm256_setzero_pd(); \ + ymm17 = _mm256_setzero_pd(); \ + ymm18 = _mm256_setzero_pd(); \ + ymm19 = _mm256_setzero_pd(); \ + ymm20 = _mm256_setzero_pd(); \ + ymm21 = _mm256_setzero_pd(); \ + + +#define BLIS_SET_ALL_YMM_REG_ZEROS \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm11 = _mm256_setzero_pd(); \ + ymm12 = _mm256_setzero_pd(); \ + ymm13 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + + + +static err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + // number of columns of OP(A), will be updated if OP(A) is Transpose(A) + gint_t K = bli_obj_width( a ); + gint_t L = M * N; + + if(L && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C + + dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; + dcomplex *tA_packed; //temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + guint_t col_idx_start; //starting index after A matrix is packed. + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm0, ymm1, ymm2, ymm3; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 4.(M%4) + + dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + dcomplex *D_A_pack = NULL; + rntm_t rntm; + + //update the pointer math if matrix B needs to be transposed. + if (bli_obj_has_trans( b )) + { + tb_inc_col = 1; //switch row and column strides + tb_inc_row = ldb; + } + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when + * needed. However, using this global array make the function + * non-reentrant. Instead of using a global array we should allocate + * buffer for each invocation. Since the buffer size is too big or stack + * and doing malloc every time will be too expensive, better approach is + * to get the buffer from the pre-allocated pool and it the pool once we + * are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can + * receive the memory broker (via rntm). Following hack will get the + * global memory broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + + if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) + { + required_packing_A = 0; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n", + buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for Z_MRxN columns of C matrix, thus + * accessing the Z_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension Z_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + col_idx_start = 0; + tA_packed = A; + row_idx_packed = row_idx; + lda_packed = lda; + + /** + * This is the part of the pack and compute optimization. + * During the first column iteration, we store the accessed A + * matrix into contiguous static memory. This helps to keep te A + * matrix in Cache and aviods the TLB misses. + */ + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = D_A_pack; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B + // matrix i data and multiplies it with + // the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *) + (tA_packed + 2), ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) * + 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; + col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts + // the B matrix data and multiplies it + // with the A matrix. This loop is + // processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts the + // B matrix data and multiplies it with + // the A matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + m_remainder = M - row_idx; + + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } +}; + +static err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A) + + + if (N<3) //Implemenation assumes that N is atleast 3. + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "N < 3, cannot be processed by small gemm" + ); + return BLIS_NOT_YET_IMPLEMENTED; + } + + if( M && N && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A) + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B) + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C + + dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; + dcomplex *tA_packed; // temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + + dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + dcomplex *D_A_pack = NULL; + rntm_t rntm; + + if( bli_obj_has_trans( b ) ) + { + tb_inc_col = 1; // switch row and column strides + tb_inc_row = ldb; + } + + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm0, ymm1, ymm2, ymm3; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 16.(M%16) + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when + * needed. + * However, using this global array make the function non-reentrant. + * Instead of using a global array we should allocate buffer for each + * invocation. + * Since the buffer size is too big or stack and doing malloc every time + * will be too expensive, + * better approach is to get the buffer from the pre-allocated pool and + * return + * it the pool once we are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can + * receive + * the memory broker (via rntm). Following hack will get the global memory + * broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) + { + required_packing_A = 0; + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", + buffer_size); #endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = Z_MR; + + // Pack 16xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 2; x += 1) + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 2 * lda; + tA_packed = D_A_pack + (x + 1)*2; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + + // Process NR columns of C matrix at a time. + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + + m_remainder = M - row_idx; + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 3; + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + ymm3 = _mm256_loadu_pd((double const *) + (tA_temp + 2 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + xmm0 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed + 2), + xmm0); + + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + xmm0 = _mm256_extractf128_pd(ymm3, 1); + _mm_storeu_pd((double *) + (tA_packed + 1 * lda_packed + 2), + xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + tA_packed[2].real = tA_temp[2 * lda].real; + tA_packed[2].imag = tA_temp[2 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 3; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 2; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 1; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + + xmm0 = _mm256_extractf128_pd(ymm0, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + xmm0); + + xmm0 = _mm256_extractf128_pd(ymm0, 1); + _mm_storeu_pd((double *)(tA_packed + 1 + * lda_packed), xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 1; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_dgemm_small_At(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for dgemm_small_At." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } +}; +#endif From d01d19c7fbdc7e93f72f93fb5051f26c866619de Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Wed, 30 Mar 2022 07:16:24 -0500 Subject: [PATCH 36/63] Fixed ztrsm computational failure - Fixed memory access for edge cases such that all load are within memory boundary only. - Corrected ztrsm utility APIs for dcomplex multiplication and division. AMD-Internal: [CPUPL-2093] Change-Id: Ib2c65e7921f6391b530cd20d6ea6b50f24bd705e --- kernels/zen/3/bli_trsm_small.c | 771 ++++++++++++++++++++++++--------- 1 file changed, 567 insertions(+), 204 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 0fa8f66d5a..32b7647a50 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -3891,33 +3891,20 @@ err_t bli_trsm_small */ #define DCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - double dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ +/* dcomplex inva = {1.0, 0.0};*/\ + a.real = 1.0;\ + a.imag = 0.0;\ + bli_zinvscals(b, a);\ } #define DCOMPLEX_MUL(a, b, c) {\ - double real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - double imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ + c.real = b.real;\ + c.imag = b.imag;\ + bli_zscals(a,c);\ } #define DCOMPLEX_DIV(a, b){\ - double dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ + bli_zinvscals(b,a); \ } @@ -3946,11 +3933,8 @@ err_t bli_trsm_small #define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ if(!is_unitdiag)\ {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - DCOMPLEX_MUL(c, a, c)\ - DCOMPLEX_DIV(c, b)\ - }\ + bli_zinvscals(b, c);\ + }\ } #endif @@ -4299,6 +4283,213 @@ BLIS_INLINE err_t ztrsm_AuXB_ref _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\ } + +#define BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm16);\ + ymm1 = _mm256_mul_pd(ymm1, ymm16);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \ + \ + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + \ + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\ + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\ + \ + tptr += 2; \ + a10 += p_lda; \ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ + ymm14 = _mm256_permute_pd(ymm14, 0x5);\ + ymm15 = _mm256_permute_pd(ymm15, 0x5);\ + \ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm11 = _mm256_addsub_pd(ymm11, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm7);\ + ymm10 = _mm256_addsub_pd(ymm10, ymm14);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm15);\ +} + + +#define BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double * )b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \ + \ + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\ + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\ + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm6 = _mm256_permute_pd(ymm6, 0x5);\ + ymm7 = _mm256_permute_pd(ymm7, 0x5);\ +\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ + ymm9 = _mm256_addsub_pd(ymm9, ymm6);\ + ymm13 = _mm256_addsub_pd(ymm13, ymm7);\ +} + +#define BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\ + double *tptr = (double *)b01;\ + if(conjtransa) {\ + ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm0 = _mm256_mul_pd(ymm0, ymm18);\ + ymm1 = _mm256_mul_pd(ymm1, ymm18);\ + \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + else {\ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \ + ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\ + ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \ + \ + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\ + \ + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\ + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\ + tptr += 2; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + }\ + }\ + ymm4 = _mm256_permute_pd(ymm4, 0x5);\ + ymm5 = _mm256_permute_pd(ymm5, 0x5);\ + ymm8 = _mm256_addsub_pd(ymm8, ymm4);\ + ymm12 = _mm256_addsub_pd(ymm12, ymm5);\ +} + + /** * Performs GEMM operation. * Two elements of column in ymm0 @@ -31943,75 +32134,160 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB if(m_rem == 3) { dim_t p_lda = 4; - if(transa) - { - for(dim_t x = 0; x < i; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm10 = _mm256_loadu_pd((double const *) - (a10 + 2)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - ymm11 = _mm256_loadu_pd((double const *) - (a10 + 2 + cs_a)); + if(transa) + { + dim_t x = 0; + for(x = 0; (x+3) < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm10 = _mm256_loadu_pd((double const *) + (a10 + 2)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + ymm11 = _mm256_loadu_pd((double const *) + (a10 + 2 + cs_a)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm10 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm10,ymm1,0x31); + + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*3 + 2), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+2) < i; x += 3) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); + xmm4 = _mm_loadu_pd((double const *) + (a10 + 2 + cs_a)); + ymm11 = _mm256_insertf128_pd(ymm11, xmm4, 0); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2), ymm8); + + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + xmm4 = _mm_loadu_pd((double const *)(a10 + + 2 * cs_a + 2)); + ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0); + ymm1 = _mm256_set_pd(1, 1, 1, 1); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20); + + + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda*2 + 2), ymm8); + + a10 += 3; + ptr_a10_dup += p_lda * p_lda; + } + for(; (x+1) < i; x += 2) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *) + (a10 + cs_a)); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda), ymm7); - ymm0 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a)); - ymm10 = _mm256_loadu_pd((double const *)(a10 - + 2 * cs_a + 2)); + ymm0 = _mm256_loadu_pd((double const *)(a10 + + 2 * cs_a)); + ymm1 = _mm256_set_pd(1, 1, 1, 1); - ymm1 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(a10 - + 3 * cs_a + 2)); + ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20); - ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup + 2), - ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda + 2), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*2 + 2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda*3 + 2), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + + p_lda + 2), ymm7); - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } + a10 += 2; + ptr_a10_dup += p_lda * p_lda; + } + for(; x < i; x += 1) + { + xmm4 = _mm_loadu_pd((double const *)(a10)); + xmm5 = _mm_loadu_pd((double const *) + (a10 + cs_a)); - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { @@ -33429,37 +33791,38 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB } else if(m_remainder == 1) { - dim_t p_lda = 2; // packed leading dimension - if(transa) - { - for(dim_t x = 0; x < m-m_remainder; x += p_lda) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *) - (a10 + cs_a)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + - p_lda), ymm7); - - a10 += p_lda; - ptr_a10_dup += p_lda * p_lda; - } - - } - else - { - for(dim_t x=0;x 0; j -= d_nr) { From 4d972e219b1800c5999cebcc6185ad5bc39a8dd6 Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Mon, 4 Apr 2022 16:08:18 +0530 Subject: [PATCH 37/63] Performance Improvement for ztrsm small sizes Details: - Enable ztrsm small implementation - For small sizes, Right Variants and Left Unit Diag Variants are using ztrsm_small implementations. - Optimization of Left Non-Unit Diagonal Variants, Work In Progress AMD-Internal: [SWLCSG-1194] Change-Id: Ib3cce6e2e4ac0817ccd4dff4bb0fa4a23e231ca4 --- frame/compat/bla_trsm_amd.c | 11 ++++++----- frame/include/bli_gentfunc_macro_defs.h | 6 +++++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 21b2a1598d..eb5c835ff5 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -902,7 +902,6 @@ void dtrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } -#if 0 void ztrsm_ ( const f77_char* side, @@ -1184,8 +1183,10 @@ void ztrsm_ * In case of multithread when [m,n]<=128 sinlge thread implemenation * is doing better than native multithread */ bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=500 && n0<=500) || - (nt && (m0+n0)<128) ) + + if((blis_side == BLIS_RIGHT) || (blis_diaga == BLIS_UNIT_DIAG)) { + if(((nt==0) && (m0<=500) && (n0<=500)) || + (nt && ((m0+n0)<128))) { err_t status; status = bli_trsm_small @@ -1205,6 +1206,7 @@ void ztrsm_ return; } } + } #endif bli_trsmnat @@ -1221,7 +1223,6 @@ void ztrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } -#endif #if 0 void ctrsm_ ( @@ -1539,6 +1540,6 @@ void ctrsm_ bli_finalize_auto(); } #endif -INSERT_GENTFUNC_BLAS_CZ( trsm, trsm ) +INSERT_GENTFUNC_BLAS_C( trsm, trsm ) #endif diff --git a/frame/include/bli_gentfunc_macro_defs.h b/frame/include/bli_gentfunc_macro_defs.h index 1bac7aa7c4..49c79cb8ae 100644 --- a/frame/include/bli_gentfunc_macro_defs.h +++ b/frame/include/bli_gentfunc_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -67,6 +67,10 @@ GENTFUNC( scomplex, c, blasname, blisname ) GENTFUNC( scomplex, c, blasname, blisname ) \ GENTFUNC( dcomplex, z, blasname, blisname ) +#define INSERT_GENTFUNC_BLAS_C( blasname, blisname ) \ +\ +GENTFUNC( scomplex, c, blasname, blisname ) + // -- Basic one-operand macro with real domain only -- From 3ff181008a617c8b220ae2cf6667206f61a75766 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 1 Apr 2022 09:15:14 +0530 Subject: [PATCH 38/63] Changes to enable zgemm small from BLAS Layer 1. Removed small gemm call from native path to avoid Single threaded calls as a part of MultiThreaded scenarios. 2. SUP and INDUCED Method path disabled. 3. Added AOCL Dynamic for optimum number of threads to achieve higher performance. Change-Id: I3c41641bef4906bdbdb5f05e67c0f61e86025d92 --- frame/3/gemm/bli_gemm_front.c | 16 - frame/3/gemm/bli_gemm_front_amd.c | 26 +- frame/base/bli_rntm.c | 16 + frame/compat/bla_gemm_amd.c | 1254 ++- kernels/zen/3/bli_gemm_small.c | 15389 ++++++++++++++-------------- kernels/zen/bli_kernels_zen.h | 22 + 6 files changed, 8370 insertions(+), 8353 deletions(-) diff --git a/frame/3/gemm/bli_gemm_front.c b/frame/3/gemm/bli_gemm_front.c index 46e163c026..a9bada995d 100644 --- a/frame/3/gemm/bli_gemm_front.c +++ b/frame/3/gemm/bli_gemm_front.c @@ -74,22 +74,6 @@ void bli_gemm_front return; } -#ifdef BLIS_ENABLE_SMALL_MATRIX - // Only handle small problems separately for homogeneous datatypes. - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && - bli_obj_dt( a ) == bli_obj_dt( c ) && - bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) - { - err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - } -#endif - // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); diff --git a/frame/3/gemm/bli_gemm_front_amd.c b/frame/3/gemm/bli_gemm_front_amd.c index a29a0bb85b..34b41f0568 100644 --- a/frame/3/gemm/bli_gemm_front_amd.c +++ b/frame/3/gemm/bli_gemm_front_amd.c @@ -50,6 +50,16 @@ void bli_gemm_front AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_3); bli_init_once(); + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_dcomplex(c))// This will enable for ZGEMM + { + bli_nthreads_optimum(a, b, c, BLIS_GEMM, rntm); + } + #endif + obj_t a_local; obj_t b_local; obj_t c_local; @@ -74,22 +84,6 @@ void bli_gemm_front return; } -#ifdef BLIS_ENABLE_SMALL_MATRIX - // Only handle small problems separately for homogeneous datatypes. - if ( bli_obj_dt( a ) == bli_obj_dt( b ) && - bli_obj_dt( a ) == bli_obj_dt( c ) && - bli_obj_comp_prec( c ) == bli_obj_prec( c ) ) - { - err_t status = bli_gemm_small( alpha, a, b, beta, c, cntx, cntl ); - - if ( status == BLIS_SUCCESS ) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); - return; - } - } -#endif - // Alias A, B, and C in case we need to apply transformations. bli_obj_alias_to( a, &a_local ); bli_obj_alias_to( b, &b_local ); diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 7176dacc4e..c597074f58 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -600,6 +600,22 @@ void bli_nthreads_optimum( } } + else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) + { + + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + + if((m<=128 || n<=128 || k<=128) && (m+n+k <= 400) ) + { + n_threads_ideal = 8; + } + else if((m<=256 || n<=256 || k<=256) && (m+n+k <= 800) ) + { + n_threads_ideal = 16; + } + } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { dim_t n = bli_obj_length(c); diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index ff995b5f07..7060509de2 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -55,76 +55,76 @@ void PASTEF77(ch,blasname) \ const ftype* a, const f77_int* lda, \ const ftype* b, const f77_int* ldb, \ const ftype* beta, \ - ftype* c, const f77_int* ldc \ + ftype* c, const f77_int* ldc \ ) \ { \ - trans_t blis_transa; \ - trans_t blis_transb; \ - dim_t m0, n0, k0; \ - inc_t rs_a, cs_a; \ - inc_t rs_b, cs_b; \ - inc_t rs_c, cs_c; \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ + inc_t rs_a, cs_a; \ + inc_t rs_b, cs_b; \ + inc_t rs_c, cs_c; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ - /* Perform BLAS parameter checking. */ \ - PASTEBLACHK(blasname) \ - ( \ - MKSTR(ch), \ - MKSTR(blasname), \ - transa, \ - transb, \ - m, \ - n, \ - k, \ - lda, \ - ldb, \ - ldc \ - ); \ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ \ - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ \ - /* Typecast BLAS integers to BLIS integers. */ \ - bli_convert_blas_dim1( *m, m0 ); \ - bli_convert_blas_dim1( *n, n0 ); \ - bli_convert_blas_dim1( *k, k0 ); \ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ \ - /* Set the row and column strides of the matrix operands. */ \ - rs_a = 1; \ - cs_a = *lda; \ - rs_b = 1; \ - cs_b = *ldb; \ - rs_c = 1; \ - cs_c = *ldc; \ + /* Set the row and column strides of the matrix operands. */ \ + rs_a = 1; \ + cs_a = *lda; \ + rs_b = 1; \ + cs_b = *ldb; \ + rs_c = 1; \ + cs_c = *ldc; \ \ - /* Call BLIS interface. */ \ - PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ - ( \ - blis_transa, \ - blis_transb, \ - m0, \ - n0, \ - k0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a, \ - (ftype*)b, rs_b, cs_b, \ - (ftype*)beta, \ - (ftype*)c, rs_c, cs_c, \ - NULL, \ - NULL \ - ); \ + /* Call BLIS interface. */ \ + PASTEMAC2(ch,blisname,BLIS_TAPI_EX_SUF) \ + ( \ + blis_transa, \ + blis_transb, \ + m0, \ + n0, \ + k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, rs_b, cs_b, \ + (ftype*)beta, \ + (ftype*)c, rs_c, cs_c, \ + NULL, \ + NULL \ + ); \ \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } #else @@ -143,175 +143,175 @@ void PASTEF77(ch,blasname) \ const ftype* a, const f77_int* lda, \ const ftype* b, const f77_int* ldb, \ const ftype* beta, \ - ftype* c, const f77_int* ldc \ + ftype* c, const f77_int* ldc \ ) \ { \ \ - trans_t blis_transa; \ - trans_t blis_transb; \ - dim_t m0, n0, k0; \ + trans_t blis_transa; \ + trans_t blis_transb; \ + dim_t m0, n0, k0; \ \ - dim_t m0_a, n0_a; \ - dim_t m0_b, n0_b; \ + dim_t m0_a, n0_a; \ + dim_t m0_b, n0_b; \ \ - /* Initialize BLIS. */ \ - bli_init_auto(); \ + /* Initialize BLIS. */ \ + bli_init_auto(); \ \ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ - AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1); \ + AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(ch), *transa, *transb, *m, *n, *k, \ + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); \ \ - /* Perform BLAS parameter checking. */ \ - PASTEBLACHK(blasname) \ - ( \ - MKSTR(ch), \ - MKSTR(blasname), \ - transa, \ - transb, \ - m, \ - n, \ - k, \ - lda, \ - ldb, \ - ldc \ - ); \ + /* Perform BLAS parameter checking. */ \ + PASTEBLACHK(blasname) \ + ( \ + MKSTR(ch), \ + MKSTR(blasname), \ + transa, \ + transb, \ + m, \ + n, \ + k, \ + lda, \ + ldb, \ + ldc \ + ); \ \ - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ \ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); \ + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); \ \ - /* Typecast BLAS integers to BLIS integers. */ \ - bli_convert_blas_dim1( *m, m0 ); \ - bli_convert_blas_dim1( *n, n0 ); \ - bli_convert_blas_dim1( *k, k0 ); \ + /* Typecast BLAS integers to BLIS integers. */ \ + bli_convert_blas_dim1( *m, m0 ); \ + bli_convert_blas_dim1( *n, n0 ); \ + bli_convert_blas_dim1( *k, k0 ); \ \ - /* Set the row and column strides of the matrix operands. */ \ - const inc_t rs_a = 1; \ - const inc_t cs_a = *lda; \ - const inc_t rs_b = 1; \ - const inc_t cs_b = *ldb; \ - const inc_t rs_c = 1; \ - const inc_t cs_c = *ldc; \ + /* Set the row and column strides of the matrix operands. */ \ + const inc_t rs_a = 1; \ + const inc_t cs_a = *lda; \ + const inc_t rs_b = 1; \ + const inc_t cs_b = *ldb; \ + const inc_t rs_c = 1; \ + const inc_t cs_c = *ldc; \ \ - if( n0 == 1 ) \ - { \ - if(bli_is_notrans(blis_transa)) \ - { \ - PASTEMAC(ch,gemv_unf_var2)( \ - BLIS_NO_TRANSPOSE, \ - bli_extract_conj(blis_transb), \ - m0, k0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a,\ - (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ - (ftype*) beta, \ - c, rs_c, \ - NULL \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,gemv_unf_var1)( \ - blis_transa, \ - bli_extract_conj(blis_transb), \ - k0, m0, \ - (ftype*)alpha, \ - (ftype*)a, rs_a, cs_a, \ - (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ - (ftype*)beta, \ - c, rs_c, \ - NULL \ - ); \ - } \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - return; \ - } \ - else if( m0 == 1 ) \ - { \ - if(bli_is_notrans(blis_transb)) \ - { \ - PASTEMAC(ch,gemv_unf_var1)( \ - blis_transb, \ - bli_extract_conj(blis_transa), \ - n0, k0, \ - (ftype*)alpha, \ - (ftype*)b, cs_b, rs_b, \ - (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ - (ftype*)beta, \ - c, cs_c, \ - NULL \ - ); \ - } \ - else \ - { \ - PASTEMAC(ch,gemv_unf_var2)( \ - blis_transb, \ - bli_extract_conj(blis_transa), \ - k0, n0, \ - (ftype*)alpha, \ - (ftype*)b, cs_b, rs_b, \ - (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ - (ftype*)beta, \ - c, cs_c, \ - NULL \ - ); \ - } \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - return; \ - } \ + if( n0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transa)) \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + BLIS_NO_TRANSPOSE, \ + bli_extract_conj(blis_transb), \ + m0, k0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a,\ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*) beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transa, \ + bli_extract_conj(blis_transb), \ + k0, m0, \ + (ftype*)alpha, \ + (ftype*)a, rs_a, cs_a, \ + (ftype*)b, bli_is_notrans(blis_transb)?rs_b:cs_b, \ + (ftype*)beta, \ + c, rs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ + else if( m0 == 1 ) \ + { \ + if(bli_is_notrans(blis_transb)) \ + { \ + PASTEMAC(ch,gemv_unf_var1)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + n0, k0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + else \ + { \ + PASTEMAC(ch,gemv_unf_var2)( \ + blis_transb, \ + bli_extract_conj(blis_transa), \ + k0, n0, \ + (ftype*)alpha, \ + (ftype*)b, cs_b, rs_b, \ + (ftype*)a, bli_is_notrans(blis_transa)?cs_a:rs_a, \ + (ftype*)beta, \ + c, cs_c, \ + NULL \ + ); \ + } \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + return; \ + } \ \ - const num_t dt = PASTEMAC(ch,type); \ + const num_t dt = PASTEMAC(ch,type); \ \ - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t ao = BLIS_OBJECT_INITIALIZER; \ - obj_t bo = BLIS_OBJECT_INITIALIZER; \ - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ - obj_t co = BLIS_OBJECT_INITIALIZER; \ + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t ao = BLIS_OBJECT_INITIALIZER; \ + obj_t bo = BLIS_OBJECT_INITIALIZER; \ + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; \ + obj_t co = BLIS_OBJECT_INITIALIZER; \ \ - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); \ + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); \ \ - bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ - bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)alpha, &alphao ); \ + bli_obj_init_finish_1x1( dt, (ftype*)beta, &betao ); \ \ - bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ - bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ - bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ + bli_obj_init_finish( dt, m0_a, n0_a, (ftype*)a, rs_a, cs_a, &ao ); \ + bli_obj_init_finish( dt, m0_b, n0_b, (ftype*)b, rs_b, cs_b, &bo ); \ + bli_obj_init_finish( dt, m0, n0, (ftype*)c, rs_c, cs_c, &co ); \ \ - bli_obj_set_conjtrans( blis_transa, &ao ); \ - bli_obj_set_conjtrans( blis_transb, &bo ); \ + bli_obj_set_conjtrans( blis_transa, &ao ); \ + bli_obj_set_conjtrans( blis_transb, &bo ); \ \ - PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ - ( \ - &alphao, \ - &ao, \ - &bo, \ - &betao, \ - &co, \ - NULL, \ - NULL \ - ); \ + PASTEMAC(blisname,BLIS_OAPI_EX_SUF) \ + ( \ + &alphao, \ + &ao, \ + &bo, \ + &betao, \ + &co, \ + NULL, \ + NULL \ + ); \ \ - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ - /* Finalize BLIS. */ \ - bli_finalize_auto(); \ + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); \ + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) \ + /* Finalize BLIS. */ \ + bli_finalize_auto(); \ } #endif #ifdef BLIS_ENABLE_BLAS void dgemm_ ( - const f77_char* transa, - const f77_char* transb, - const f77_int* m, - const f77_int* n, - const f77_int* k, - const double* alpha, - const double* a, const f77_int* lda, - const double* b, const f77_int* ldb, - const double* beta, - double* c, const f77_int* ldc + const f77_char* transa, + const f77_char* transb, + const f77_int* m, + const f77_int* n, + const f77_int* k, + const double* alpha, + const double* a, const f77_int* lda, + const double* b, const f77_int* ldb, + const double* beta, + double* c, const f77_int* ldc ) { @@ -343,7 +343,7 @@ void dgemm_ ldc ); - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ bli_param_map_netlib_to_blis_trans(*transa, &blis_transa); bli_param_map_netlib_to_blis_trans(*transb, &blis_transb); @@ -361,141 +361,141 @@ void dgemm_ const inc_t rs_c = 1; const inc_t cs_c = *ldc; - // This function is invoked on all architectures including ‘generic’. - // Non-AVX platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx_supported() == FALSE) - { - // This code is duplicated below, however we don't want to move it out of - // this IF block as it will affect the performance on Zen architetures - // Also this is temporary fix which will be replaced later. - const num_t dt = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); - bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); - - bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); - bli_obj_init_finish_1x1(dt, (double *)beta, &betao); - - bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); - bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); - bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); - - bli_obj_set_conjtrans(blis_transa, &ao); - bli_obj_set_conjtrans(blis_transb, &bo); - - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - - if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == FALSE) { - bli_dgemm_ref_k1_nn( m0, n0, k0, - (double*)alpha, - (double*)a, *lda, - (double*)b, *ldb, - (double*)beta, - c, *ldc - ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS */ - bli_finalize_auto(); - - return; + // This code is duplicated below, however we don't want to move it out of + // this IF block as it will affect the performance on Zen architetures + // Also this is temporary fix which will be replaced later. + const num_t dt = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans(blis_transa, m0, k0, &m0_a, &n0_a); + bli_set_dims_with_trans(blis_transb, k0, n0, &m0_b, &n0_b); + + bli_obj_init_finish_1x1(dt, (double *)alpha, &alphao); + bli_obj_init_finish_1x1(dt, (double *)beta, &betao); + + bli_obj_init_finish(dt, m0_a, n0_a, (double *)a, rs_a, cs_a, &ao); + bli_obj_init_finish(dt, m0_b, n0_b, (double *)b, rs_b, cs_b, &bo); + bli_obj_init_finish(dt, m0, n0, (double *)c, rs_c, cs_c, &co); + + bli_obj_set_conjtrans(blis_transa, &ao); + bli_obj_set_conjtrans(blis_transb, &bo); + + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL + ); + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + + if((k0 == 1) && bli_is_notrans(blis_transa) && bli_is_notrans(blis_transb)) + { + bli_dgemm_ref_k1_nn( m0, n0, k0, + (double*)alpha, + (double*)a, *lda, + (double*)b, *ldb, + (double*)beta, + c, *ldc + ); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS */ + bli_finalize_auto(); + + return; } if (n0 == 1) { - if (bli_is_notrans(blis_transa)) - { - bli_dgemv_unf_var2( - BLIS_NO_TRANSPOSE, - bli_extract_conj(blis_transb), - m0, k0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var1( - blis_transa, - bli_extract_conj(blis_transb), - k0, m0, - (double*)alpha, - (double*)a, rs_a, cs_a, - (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, - (double*)beta, - c, rs_c, - ((void*)0) - ); - } - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - - return; + if (bli_is_notrans(blis_transa)) + { + bli_dgemv_unf_var2( + BLIS_NO_TRANSPOSE, + bli_extract_conj(blis_transb), + m0, k0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var1( + blis_transa, + bli_extract_conj(blis_transb), + k0, m0, + (double*)alpha, + (double*)a, rs_a, cs_a, + (double*)b, bli_is_notrans(blis_transb) ? rs_b : cs_b, + (double*)beta, + c, rs_c, + ((void*)0) + ); + } + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + + return; } else if (m0 == 1) { - if (bli_is_notrans(blis_transb)) - { - bli_dgemv_unf_var1( - blis_transb, - bli_extract_conj(blis_transa), - n0, k0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - else - { - bli_dgemv_unf_var2( - blis_transb, - bli_extract_conj(blis_transa), - k0, n0, - (double*)alpha, - (double*)b, cs_b, rs_b, - (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, - (double*)beta, - c, cs_c, - ((void*)0) - ); - } - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; + if (bli_is_notrans(blis_transb)) + { + bli_dgemv_unf_var1( + blis_transb, + bli_extract_conj(blis_transa), + n0, k0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + else + { + bli_dgemv_unf_var2( + blis_transb, + bli_extract_conj(blis_transa), + k0, n0, + (double*)alpha, + (double*)b, cs_b, rs_b, + (double*)a, bli_is_notrans(blis_transa) ? cs_a : rs_a, + (double*)beta, + c, cs_c, + ((void*)0) + ); + } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; } const num_t dt = BLIS_DOUBLE; @@ -527,29 +527,29 @@ void dgemm_ bool nt = bli_thread_get_is_parallel(); // Check if parallel dgemm is invoked. #ifdef AOCL_DYNAMIC - //For smaller sizes dgemm_small is perfoming better + //For smaller sizes dgemm_small is perfoming better if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>150)) ) #else if (nt) #endif { - // Will call parallelized dgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL + // Will call parallelized dgemm code - sup & native + PASTEMAC(gemm, BLIS_OAPI_EX_SUF) + ( + &alphao, + &ao, + &bo, + &betao, + &co, + NULL, + NULL ); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; } // The code below will be called when number of threads = 1. @@ -558,71 +558,71 @@ void dgemm_ //if( ((m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) && (n0 > 2)) if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || - ((n0 <= 10) && (k0 <=10)) ) + ((n0 <= 10) && (k0 <=10)) ) + { + err_t status; + if (bli_is_notrans(blis_transa)) + { + status = bli_dgemm_small( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else { - err_t status; - if (bli_is_notrans(blis_transa)) - { - status = bli_dgemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - else - { - status = bli_dgemm_small_At ( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } + status = bli_dgemm_small_At ( &alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + + return; + } } #endif //#ifdef BLIS_ENABLE_SMALL_MATRIX err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - return; - } - - // fall back on native path when dgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ - /* ( */ - /* &alphao, */ - /* &ao, */ - /* &bo, */ - /* &betao, */ - /* &co, */ - /* NULL, */ - /* NULL */ - /* ); */ - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + return; + } + + // fall back on native path when dgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + /* PASTEMAC(gemm, BLIS_OAPI_EX_SUF) */ + /* ( */ + /* &alphao, */ + /* &ao, */ + /* &bo, */ + /* &betao, */ + /* &co, */ + /* NULL, */ + /* NULL */ + /* ); */ + + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); } // end of dgemm_ void zgemm_ @@ -648,176 +648,166 @@ void zgemm_ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // default instance peformance tuning is done in zgemm. - // Single instance tuning is done based on env set. - dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); - - //dim_t nt = bli_thread_get_num_threads(); // get number of threads - bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. -#ifdef AOCL_DYNAMIC - //For smaller sizes zgemm_small is perfoming better - if (nt && (((m0 >32) || (n0>32) || (k0>32)) && ((m0+n0+k0)>100)) ) -#else - if (nt) -#endif - { - // Will call parallelized zgemm code - sup & native - PASTEMAC(gemm, BLIS_OAPI_EX_SUF) - ( - &alphao, - &ao, - &bo, - &betao, - &co, - NULL, - NULL - ); - - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt, m0_a, n0_a, (dcomplex*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // default instance peformance tuning is done in zgemm. + // Single instance tuning is done based on env set. + //dim_t single_instance = bli_env_get_var( "BLIS_SINGLE_INSTANCE", -1 ); + + //dim_t nt = bli_thread_get_num_threads(); // get number of threads + bool nt = bli_thread_get_is_parallel(); // Check if parallel zgemm is invoked. #ifdef BLIS_ENABLE_SMALL_MATRIX - err_t status; - - if((nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 )) - { - status = bli_gemm_small( &alphao, - &ao, - &bo, - &betao, - &co, - NULL, //cntx, - NULL - ); - } - - if (status == BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); - /* Finalize BLIS. */ - bli_finalize_auto(); - - return; - } + + if( ( (nt == 0) && (m0 <= 512 ) && ( n0 <= 512 ) && ( k0 <= 512 ) ) || + ( (nt == 1) && ((( m0 <= 32)||(n0 <= 32)||(k0 <=32)) && ((m0+n0+k0)<=100)) ) + ) + { + err_t status = BLIS_NOT_YET_IMPLEMENTED; + if (bli_is_notrans(blis_transa)) + { + status = bli_zgemm_small(&alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + else + { + status = bli_zgemm_small_At(&alphao, + &ao, + &bo, + &betao, + &co, + NULL, //cntx, + NULL + ); + } + + if (status == BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } #endif + // The code below will be called when number of threads = 1. -#if ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) +#if 0//ENABLE_INDUCED_METHOD + /* 3m_sqp is optimal for certain matrix shapes. + Initial study that it works well for square sizes and sizes closer to square shape. + + * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. + * Further investigation is necessary to make the usage choices more generic. */ + bool sqp_on = false; + if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) + { + sqp_on = true; + } + + // current range of sizes used for 3m_sqp to be expaned after evaluation. + if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) && ( k0 == 1120 ) ) //to be tuned further. - { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) - { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } + { + sqp_on = true; + } + + if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) + { + //sqp algo is found better for n > 40 + if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + } #endif//ENABLE_INDUCED_METHOD -// native tuning resulted in better numbers compared to sup in constrained multi-instance -// sup has been enabled for single instance cases. - if(single_instance==1) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - - } - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ - bli_finalize_auto(); +// sup has been disabled. + if(0) + { + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) + { + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + } + + } + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); }// end of zgemm_ @@ -851,72 +841,72 @@ void dzgemm_ AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_1) AOCL_DTL_LOG_GEMM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, *MKSTR(z), *transa, *transb, *m, *n, *k, - (void*)alpha, *lda, *ldb, (void*)beta, *ldc); + (void*)alpha, *lda, *ldb, (void*)beta, *ldc); /* Perform BLAS parameter checking. */ - PASTEBLACHK(gemm) - ( - MKSTR(z), - MKSTR(gemm), - transa, - transb, - m, - n, - k, - lda, - ldb, - ldc - ); - - /* Map BLAS chars to their corresponding BLIS enumerated type value. */ - bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); - bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); - - /* Typecast BLAS integers to BLIS integers. */ - bli_convert_blas_dim1( *m, m0 ); - bli_convert_blas_dim1( *n, n0 ); - bli_convert_blas_dim1( *k, k0 ); - - /* Set the row and column strides of the matrix operands. */ - const inc_t rs_a = 1; - const inc_t cs_a = *lda; - const inc_t rs_b = 1; - const inc_t cs_b = *ldb; - const inc_t rs_c = 1; - const inc_t cs_c = *ldc; - - const num_t dt = BLIS_DCOMPLEX; - const num_t dt_a = BLIS_DOUBLE; - - obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t ao = BLIS_OBJECT_INITIALIZER; - obj_t bo = BLIS_OBJECT_INITIALIZER; - obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; - obj_t co = BLIS_OBJECT_INITIALIZER; - - dim_t m0_a, n0_a; - dim_t m0_b, n0_b; - - bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); - bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); - - bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); - bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); - - bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); - bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); - bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); - - bli_obj_set_conjtrans( blis_transa, &ao ); - bli_obj_set_conjtrans( blis_transb, &bo ); - - // fall back on native path when zgemm is not handled in sup path. - bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - /* Finalize BLIS. */ - bli_finalize_auto(); + PASTEBLACHK(gemm) + ( + MKSTR(z), + MKSTR(gemm), + transa, + transb, + m, + n, + k, + lda, + ldb, + ldc + ); + + /* Map BLAS chars to their corresponding BLIS enumerated type value. */ + bli_param_map_netlib_to_blis_trans( *transa, &blis_transa ); + bli_param_map_netlib_to_blis_trans( *transb, &blis_transb ); + + /* Typecast BLAS integers to BLIS integers. */ + bli_convert_blas_dim1( *m, m0 ); + bli_convert_blas_dim1( *n, n0 ); + bli_convert_blas_dim1( *k, k0 ); + + /* Set the row and column strides of the matrix operands. */ + const inc_t rs_a = 1; + const inc_t cs_a = *lda; + const inc_t rs_b = 1; + const inc_t cs_b = *ldb; + const inc_t rs_c = 1; + const inc_t cs_c = *ldc; + + const num_t dt = BLIS_DCOMPLEX; + const num_t dt_a = BLIS_DOUBLE; + + obj_t alphao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t ao = BLIS_OBJECT_INITIALIZER; + obj_t bo = BLIS_OBJECT_INITIALIZER; + obj_t betao = BLIS_OBJECT_INITIALIZER_1X1; + obj_t co = BLIS_OBJECT_INITIALIZER; + + dim_t m0_a, n0_a; + dim_t m0_b, n0_b; + + bli_set_dims_with_trans( blis_transa, m0, k0, &m0_a, &n0_a ); + bli_set_dims_with_trans( blis_transb, k0, n0, &m0_b, &n0_b ); + + bli_obj_init_finish_1x1( dt, (dcomplex*)alpha, &alphao ); + bli_obj_init_finish_1x1( dt, (dcomplex*)beta, &betao ); + + bli_obj_init_finish( dt_a, m0_a, n0_a, (double*)a, rs_a, cs_a, &ao ); + bli_obj_init_finish( dt, m0_b, n0_b, (dcomplex*)b, rs_b, cs_b, &bo ); + bli_obj_init_finish( dt, m0, n0, (dcomplex*)c, rs_c, cs_c, &co ); + + bli_obj_set_conjtrans( blis_transa, &ao ); + bli_obj_set_conjtrans( blis_transb, &bo ); + + // fall back on native path when zgemm is not handled in sup path. + bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + /* Finalize BLIS. */ + bli_finalize_auto(); }// end of dzgemm_ #endif #endif diff --git a/kernels/zen/3/bli_gemm_small.c b/kernels/zen/3/bli_gemm_small.c index 18745b9c3f..0cf5c8c5ce 100644 --- a/kernels/zen/3/bli_gemm_small.c +++ b/kernels/zen/3/bli_gemm_small.c @@ -48,7 +48,7 @@ #define D_BLIS_SMALL_MATRIX_THRES (BLIS_SMALL_MATRIX_THRES / 2 ) #define D_BLIS_SMALL_M_RECT_MATRIX_THRES (BLIS_SMALL_M_RECT_MATRIX_THRES / 2) #define D_BLIS_SMALL_K_RECT_MATRIX_THRES (BLIS_SMALL_K_RECT_MATRIX_THRES / 2) -#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. +#define BLIS_ATBN_M_THRES 40 // Threshold value of M for/below which small matrix code is called. #define AT_MR 4 // The kernel dimension of the A transpose GEMM kernel.(AT_MR * NR). static err_t bli_sgemm_small ( @@ -71,7 +71,7 @@ err_t bli_dgemm_small cntx_t* cntx, cntl_t* cntl ); -static err_t bli_zgemm_small +err_t bli_zgemm_small ( obj_t* alpha, obj_t* a, @@ -81,7 +81,7 @@ static err_t bli_zgemm_small cntx_t* cntx, cntl_t* cntl ); -static err_t bli_zgemm_small_At +err_t bli_zgemm_small_At ( obj_t* alpha, obj_t* a, @@ -128,18 +128,18 @@ err_t bli_gemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + #ifdef BLIS_ENABLE_MULTITHREADING - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); - return BLIS_NOT_YET_IMPLEMENTED; + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + return BLIS_NOT_YET_IMPLEMENTED; #else // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. if (bli_cpuid_is_avx_supported() == FALSE) - { - return BLIS_NOT_YET_IMPLEMENTED; - } + { + return BLIS_NOT_YET_IMPLEMENTED; + } #endif // If alpha is zero, scale by beta and return. @@ -172,8 +172,8 @@ err_t bli_gemm_small return bli_dgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif } - if(dt == BLIS_DCOMPLEX) - { + if(dt == BLIS_DCOMPLEX) + { #ifndef BLIS_ENABLE_MULTITHREADING // bli_zgemm_small_At is called directly from blas interface for // sizes within thresholds. @@ -181,9 +181,9 @@ err_t bli_gemm_small // and directing to native implementation. return BLIS_NOT_YET_IMPLEMENTED; #else - return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); + return bli_zgemm_small_At(alpha, a, b, beta, c, cntx, cntl); #endif - } + } if (bli_obj_has_notrans( b )) { @@ -230,7 +230,7 @@ err_t bli_gemm_small return bli_sgemm_small(alpha, a, b, beta, c, cntx, cntl); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_NOT_YET_IMPLEMENTED; }; @@ -245,13 +245,13 @@ static err_t bli_sgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_TRACE_7); gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . gint_t L = M * N; - // when N is equal to 1 call GEMV instead of GEMM + // when N is equal to 1 call GEMV instead of GEMM if (N == 1) { bli_gemv @@ -262,7 +262,7 @@ static err_t bli_sgemm_small beta, c ); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } @@ -288,7 +288,7 @@ static err_t bli_sgemm_small dim_t tb_inc_row = 1; // row stride of matrix B dim_t tb_inc_col = ldb; // column stride of matrix B - __m256 ymm4, ymm5, ymm6, ymm7; + __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10, ymm11; __m256 ymm12, ymm13, ymm14, ymm15; __m256 ymm0, ymm1, ymm2, ymm3; @@ -302,7 +302,7 @@ static err_t bli_sgemm_small const num_t dt_exec = bli_obj_dt( c ); float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); - float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); /*Beta Zero Check*/ bool is_beta_non_zero=0; @@ -310,7 +310,7 @@ static err_t bli_sgemm_small is_beta_non_zero = 1; } - //update the pointer math if matrix B needs to be transposed. + //update the pointer math if matrix B needs to be transposed. if (bli_obj_has_trans( b )) { tb_inc_col = 1; //switch row and column strides tb_inc_row = ldb; @@ -339,11 +339,11 @@ static err_t bli_sgemm_small bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initialization + // We will use the same size to avoid pool re-initialization siz_t buffer_size = bli_pool_block_size(bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); - // Based on the available memory in the buffer we will decide if + // Based on the available memory in the buffer we will decide if // we want to do packing or not. // // This kernel assumes that "A" will be un-packged if N <= 3. @@ -355,18 +355,18 @@ static err_t bli_sgemm_small // If this check is removed it will result in the crash as // reported in CPUPL-587. // - + if ((N <= 3) || (((MR * K) << 2) > buffer_size)) { required_packing_A = 0; } - else + else { #ifdef BLIS_ENABLE_MEM_TRACING printf( "bli_sgemm_small: Requesting mem pool block of size %lu\n", buffer_size); #endif // Get the buffer from the pool, if there is no pool with - // required size, it will be created. + // required size, it will be created. bli_membrk_acquire_m(&rntm, buffer_size, BLIS_BITVAL_BUFFER_FOR_A_BLOCK, @@ -1668,7 +1668,7 @@ static err_t bli_sgemm_small if(is_beta_non_zero){ ymm7 = _mm256_fmadd_ps(ymm2, ymm1, ymm7); } - _mm256_storeu_ps(f_temp, ymm7); + _mm256_storeu_ps(f_temp, ymm7); for (int i = 0; i < m_remainder; i++) { tC[i] = f_temp[i]; @@ -1770,18 +1770,18 @@ static err_t bli_sgemm_small bli_membrk_release(&rntm, &local_mem_buf_A_s); } - - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); + + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_7); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } }; @@ -1796,22 +1796,25 @@ static err_t bli_sgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . gint_t L = M * N; /* if (N<3) //Implemenation assumes that N is atleast 3. VK */ - /* { */ - /* AOCL_DTL_TRACE_EXIT_ERR( */ - /* AOCL_DTL_LEVEL_INFO, */ + /* { */ + /* AOCL_DTL_TRACE_EXIT_ERR( */ + /* AOCL_DTL_LEVEL_INFO, */ /* "N < 3 cannot be processed by small_gemm" */ - /* ); */ + /* ); */ /* return BLIS_NOT_YET_IMPLEMENTED; VK */ - /* } */ - + /* } */ + if(L && K ) // Non-zero dimensions will be handled by either sup or native kernels { @@ -1884,7 +1887,7 @@ static err_t bli_sgemm_small bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton + // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); @@ -1900,12 +1903,12 @@ static err_t bli_sgemm_small // reported in CPUPL-587. // - // if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) - if ((N < 3) || ((D_MR * K) << 3) > buffer_size) + // if ((N <= 3) || ((D_MR * K) << 3) > buffer_size) + if ((N < 3) || ((D_MR * K) << 3) > buffer_size) { required_packing_A = 0; } - + if (required_packing_A == 1) { #ifdef BLIS_ENABLE_MEM_TRACING @@ -3359,17 +3362,17 @@ static err_t bli_sgemm_small bli_membrk_release(&rntm, &local_mem_buf_A_s); } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); return BLIS_NONCONFORMAL_DIMENSIONS; - } + } }; static err_t bli_sgemm_small_atbn @@ -3383,9 +3386,9 @@ static err_t bli_sgemm_small_atbn cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - gint_t M = bli_obj_length( c ); // number of rows of Matrix C + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -3410,7 +3413,7 @@ static err_t bli_sgemm_small_atbn float scratch[8] = {0.0}; const num_t dt_exec = bli_obj_dt( c ); float* restrict alpha_cast = bli_obj_buffer_for_1x1( dt_exec, alpha ); - float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); + float* restrict beta_cast = bli_obj_buffer_for_1x1( dt_exec, beta ); /*Beta Zero Check*/ bool is_beta_non_zero=0; @@ -3836,17 +3839,17 @@ static err_t bli_sgemm_small_atbn } } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); return BLIS_NONCONFORMAL_DIMENSIONS; - } + } } static err_t bli_dgemm_small_atbn @@ -3860,8 +3863,8 @@ static err_t bli_dgemm_small_atbn cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_length( b ); // number of rows of Matrix B @@ -4276,17 +4279,17 @@ static err_t bli_dgemm_small_atbn } } } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); return BLIS_SUCCESS; } else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } } err_t bli_dgemm_small_At @@ -4302,7 +4305,10 @@ err_t bli_dgemm_small_At { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } gint_t M = bli_obj_length( c ); // number of rows of Matrix C gint_t N = bli_obj_width( c ); // number of columns of Matrix C gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A), will be updated if OP(A) is Transpose(A) . @@ -4352,14 +4358,14 @@ err_t bli_dgemm_small_At if( bli_obj_has_trans( b ) ) { - tb_inc_col = 1; // switch row and column strides + tb_inc_col = 1; // switch row and column strides tb_inc_row = ldb; } __m256d ymm4, ymm5, ymm6, ymm7; __m256d ymm8, ymm9, ymm10, ymm11; __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm0, ymm1, ymm2, ymm3; double result; double scratch[8] = {0.0}; @@ -4397,7 +4403,7 @@ err_t bli_dgemm_small_At bli_membrk_rntm_set_membrk( &rntm ); // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton + // We will use the same size to avoid pool re-initliazaton siz_t buffer_size = bli_pool_block_size( bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); @@ -5780,7 +5786,7 @@ err_t bli_dgemm_small_At -static err_t bli_zgemm_small +err_t bli_zgemm_small ( obj_t* alpha, obj_t* a, @@ -5791,7635 +5797,7640 @@ static err_t bli_zgemm_small cntl_t* cntl ) { - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - bool conjtransa = bli_obj_has_conj(a); - bool conjtransb = bli_obj_has_conj(b); - - gint_t M = bli_obj_length( c ); // number of rows of Matrix C - gint_t N = bli_obj_width( c ); // number of columns of Matrix C - // number of columns of OP(A), will be updated if OP(A) is Transpose(A) - gint_t K = bli_obj_width( a ); - gint_t L = M * N; - - if(L && K ) - { - guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). - guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). - guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C - guint_t row_idx, col_idx, k; - dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A - dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B - dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C - - dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; - dcomplex *tA_packed; //temprorary pointer to hold packed A memory pointer - guint_t row_idx_packed; //packed A memory row index - guint_t lda_packed; //lda of packed A - guint_t col_idx_start; //starting index after A matrix is packed. - dim_t tb_inc_row = 1; // row stride of matrix B - dim_t tb_inc_col = ldb; // column stride of matrix B - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; - __m256d ymm0, ymm1, ymm2, ymm3; - - gint_t n_remainder; // If the N is non multiple of 3.(N%3) - gint_t m_remainder; // If the M is non multiple of 4.(M%4) - - dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); - beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s; - dcomplex *D_A_pack = NULL; - rntm_t rntm; - - //update the pointer math if matrix B needs to be transposed. - if (bli_obj_has_trans( b )) - { - tb_inc_col = 1; //switch row and column strides - tb_inc_row = ldb; - } - - //checking whether beta value is zero. - //if true, we should perform C=alpha * A*B operation - //instead of C = beta * C + alpha * (A * B) - bool is_beta_non_zero = 0; - if(!bli_obj_equals(beta, &BLIS_ZERO)) - is_beta_non_zero = 1; - - /* - * This function was using global array to pack part of A input when - * needed. However, using this global array make the function - * non-reentrant. Instead of using a global array we should allocate - * buffer for each invocation. Since the buffer size is too big or stack - * and doing malloc every time will be too expensive, better approach is - * to get the buffer from the pre-allocated pool and it the pool once we - * are doing. - * - * In order to get the buffer from pool, we need access to memory broker, - * currently this function is not invoked in such a way that it can - * receive the memory broker (via rntm). Following hack will get the - * global memory broker that can be use it to access the pool. - * - * Note there will be memory allocation at least on first innovation - * as there will not be any pool created for this size. - * Subsequent invocations will just reuse the buffer from the pool. - */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - // - // This kernel assumes that "A" will be unpackged if N <= 3. - // Usually this range (N <= 3) is handled by SUP, however, - // if SUP is disabled or for any other condition if we do - // enter this kernel with N <= 3, we want to make sure that - // "A" remains unpacked. - // - - if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) - { - required_packing_A = 0; - } - - if (required_packing_A == 1) - { -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n", - buffer_size); -#endif - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - } - - /* - * The computation loop runs for Z_MRxN columns of C matrix, thus - * accessing the Z_MRxK A matrix data and KxNR B matrix data. - * The computation is organized as inner loops of dimension Z_MRxNR. - */ - // Process D_MR rows of C matrix at a time. - for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) - { - col_idx_start = 0; - tA_packed = A; - row_idx_packed = row_idx; - lda_packed = lda; - - /** - * This is the part of the pack and compute optimization. - * During the first column iteration, we store the accessed A - * matrix into contiguous static memory. This helps to keep te A - * matrix in Cache and aviods the TLB misses. - */ - if (required_packing_A) - { - col_idx = 0; - - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - tA_packed = D_A_pack; + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); -#ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); -#endif - // clear scratch registers. - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B - // matrix i data and multiplies it with - // the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *) - (tA_packed + 2), ymm1); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) * - 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *)(tA_packed + 2) - , ymm1); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *)(tA_packed + 2) - , ymm1); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - _mm256_storeu_pd( - (double *)tA_packed, ymm0); - _mm256_storeu_pd( - (double *)(tA_packed + 2) - , ymm1); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - tA_packed += Z_MR; - } - } - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - (&beta_cast->imag)); - - - BLIS_SET_YMM_REG_ZEROS - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - // col 2 - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - // col 3 - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2))); - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2) + 2)); - ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - _mm256_storeu_pd((double *)(tC + 2), ymm13); - - // modify the pointer arithematic to use packed A matrix. - col_idx_start = NR; - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = Z_MR; - } - // Process NR columns of C matrix at a time. - for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; - col_idx += NR) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + // number of columns of OP(A), will be updated if OP(A) is Transpose(A) + gint_t K = bli_obj_width( a ); + gint_t L = M * N; -#ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); -#endif - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd( - (double const *)tA); - ymm1 = _mm256_loadu_pd( - (double const *)(tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K The inner loop broadcasts - // the B matrix data and multiplies it - // with the A matrix. This loop is - // processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K The inner loop broadcasts the - // B matrix data and multiplies it with - // the A matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - BLIS_SET_YMM_REG_ZEROS - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2 + 2)); - ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - _mm256_storeu_pd((double *)(tC + 2), ymm13); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - - tptr += (tb_inc_row * 2); - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - - tptr += (tb_inc_row * 2); - tA += lda; - } - - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)(tC + 0), ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - } - - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and multiplies it with the A - // matrix. This loop is processing - // Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += (tb_inc_row * 2); - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += (tb_inc_row * 2); - tA += lda; - } - } - else //handles non-transpose case - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - } - } - m_remainder = M - row_idx; - - if ((m_remainder == 3)) - { - m_remainder -= 3; - __m128d xmm0; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - - BLIS_SET_ALL_YMM_REG_ZEROS - - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *)(tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *)(tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc * 2 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - xmm0 = _mm256_extractf128_pd(ymm13, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr - + tb_inc_col - * 0 + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - } - if ((m_remainder == 2)) - { - m_remainder -= 2; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - - - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing Z_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - BLIS_SET_YMM_REG_ZEROS - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - _mm256_storeu_pd((double *)tC, ymm8); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - _mm256_storeu_pd((double *)tC, ymm8); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - _mm256_storeu_pd((double *)tC, ymm8); - } - } - if ((m_remainder == 1)) - { - m_remainder -= 1; - __m128d xmm0; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm10, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = A + row_idx; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda; - } - - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - } - } - // Return the buffer to pool - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_zgemm_small(): releasing mem pool block\n" ); -#endif - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return BLIS_SUCCESS; - } - else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for small gemm." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } -}; + if(L && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A). + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B). + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C -static err_t bli_zgemm_small_At - ( - obj_t* alpha, - obj_t* a, - obj_t* b, - obj_t* beta, - obj_t* c, - cntx_t* cntx, - cntl_t* cntl - ) -{ - AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); - - bool conjtransa = bli_obj_has_conj(a); - bool conjtransb = bli_obj_has_conj(b); - - gint_t M = bli_obj_length( c ); // number of rows of Matrix C - gint_t N = bli_obj_width( c ); // number of columns of Matrix C - gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A) - - - if (N<3) //Implemenation assumes that N is atleast 3. - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "N < 3, cannot be processed by small gemm" - ); - return BLIS_NOT_YET_IMPLEMENTED; - } - - if( M && N && K ) - { - guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A) - guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B) - guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C - guint_t row_idx, col_idx, k; - dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A - dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B - dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C - - dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; - dcomplex *tA_packed; // temprorary pointer to hold packed A memory pointer - guint_t row_idx_packed; //packed A memory row index - guint_t lda_packed; //lda of packed A - dim_t tb_inc_row = 1; // row stride of matrix B - dim_t tb_inc_col = ldb; // column stride of matrix B - - dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples - alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); - beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s; - dcomplex *D_A_pack = NULL; - rntm_t rntm; - - if( bli_obj_has_trans( b ) ) - { - tb_inc_col = 1; // switch row and column strides - tb_inc_row = ldb; - } - - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; - __m256d ymm0, ymm1, ymm2, ymm3; - - gint_t n_remainder; // If the N is non multiple of 3.(N%3) - gint_t m_remainder; // If the M is non multiple of 16.(M%16) - - //checking whether beta value is zero. - //if true, we should perform C=alpha * A*B operation - //instead of C = beta * C + alpha * (A * B) - bool is_beta_non_zero = 0; - if(!bli_obj_equals(beta, &BLIS_ZERO)) - is_beta_non_zero = 1; - - /* - * This function was using global array to pack part of A input when - * needed. - * However, using this global array make the function non-reentrant. - * Instead of using a global array we should allocate buffer for each - * invocation. - * Since the buffer size is too big or stack and doing malloc every time - * will be too expensive, - * better approach is to get the buffer from the pre-allocated pool and - * return - * it the pool once we are doing. - * - * In order to get the buffer from pool, we need access to memory broker, - * currently this function is not invoked in such a way that it can - * receive - * the memory broker (via rntm). Following hack will get the global memory - * broker that can be use it to access the pool. - * - * Note there will be memory allocation at least on first innovation - * as there will not be any pool created for this size. - * Subsequent invocations will just reuse the buffer from the pool. - */ - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - // Get the current size of the buffer pool for A block packing. - // We will use the same size to avoid pool re-initliazaton - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - // - // This kernel assumes that "A" will be unpackged if N <= 3. - // Usually this range (N <= 3) is handled by SUP, however, - // if SUP is disabled or for any other condition if we do - // enter this kernel with N <= 3, we want to make sure that - // "A" remains unpacked. - // - // If this check is removed it will result in the crash as - // reported in CPUPL-587. - // - - if ((N < 3) || ((Z_MR * K) << 3) > buffer_size) - { - required_packing_A = 0; - return BLIS_NOT_YET_IMPLEMENTED; - } - - if (required_packing_A == 1) - { + dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; + dcomplex *tA_packed; //temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + guint_t col_idx_start; //starting index after A matrix is packed. + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm0, ymm1, ymm2, ymm3; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 4.(M%4) + + dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + dcomplex *D_A_pack = NULL; + rntm_t rntm; + + //update the pointer math if matrix B needs to be transposed. + if (bli_obj_has_trans( b )) + { + tb_inc_col = 1; //switch row and column strides + tb_inc_row = ldb; + } + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when + * needed. However, using this global array make the function + * non-reentrant. Instead of using a global array we should allocate + * buffer for each invocation. Since the buffer size is too big or stack + * and doing malloc every time will be too expensive, better approach is + * to get the buffer from the pre-allocated pool and it the pool once we + * are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can + * receive the memory broker (via rntm). Following hack will get the + * global memory broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + + if ((N < 3) || ((Z_MR * K) << 4) > buffer_size) + { + required_packing_A = 0; + } + + if (required_packing_A == 1) + { #ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemm_small: Requesting mem pool block of size %lu\n", - buffer_size); + printf( "bli_zgemm_small: Requesting mem pool block of size %lu\n", + buffer_size); #endif - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - } - - /* - * The computation loop runs for D_MRxN columns of C matrix, thus - * accessing the D_MRxK A matrix data and KxNR B matrix data. - * The computation is organized as inner loops of dimension D_MRxNR. - */ - // Process D_MR rows of C matrix at a time. - for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) - { - - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = Z_MR; - - // Pack 16xk of matrix A into buffer - // continuous access for A and strided stores to B - for(inc_t x = 0; (x) < 2; x += 1) - { - dcomplex* tA_temp = tA; - - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - ymm2 = _mm256_loadu_pd((double const *) - (tA_temp + 1 * lda)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); - - _mm256_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - ymm6); - _mm256_storeu_pd((double *) - (tA_packed + 1 * lda_packed), - ymm7); - - tA_temp += 2; - tA_packed += 2 * lda_packed; - } - - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - tA_packed[1].real = tA_temp[1 * lda].real; - tA_packed[1].imag = tA_temp[1 * lda].imag; - - tA_temp += 1; - tA_packed += lda_packed; - } - - tA += 2 * lda; - tA_packed = D_A_pack + (x + 1)*2; - } - - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = Z_MR; - - // Process NR columns of C matrix at a time. - for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for Z_MRxN columns of C matrix, thus + * accessing the Z_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension Z_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + col_idx_start = 0; + tA_packed = A; + row_idx_packed = row_idx; + lda_packed = lda; + + /** + * This is the part of the pack and compute optimization. + * During the first column iteration, we store the accessed A + * matrix into contiguous static memory. This helps to keep te A + * matrix in Cache and aviods the TLB misses. + */ + if (required_packing_A) + { + col_idx = 0; + + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + tA_packed = D_A_pack; #ifdef BLIS_ENABLE_PREFETCH - _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); - _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); #endif - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - (&beta_cast->imag)); - - - - BLIS_SET_YMM_REG_ZEROS - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - // col 2 - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - // col 3 - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2))); - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + (ldc * 2) + 2)); - ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - _mm256_storeu_pd((double *)(tC + 2), ymm13); - - } - n_remainder = N - col_idx; - - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const*)tA); - ymm1 = _mm256_loadu_pd((double const*) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc + 2)); - ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)(tC + 0), ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - _mm256_storeu_pd((double *)(tC + 2), ymm12); - } - // if the N is not multiple of 3. - // handling edge case. - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - BLIS_SET_ALL_YMM_REG_ZEROS - double *tptr = (double *)tB; - - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - tptr += tb_inc_row*2; - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm1 = _mm256_loadu_pd((double const *) - (tA + 2)); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - - BLIS_SET_YMM_REG_ZEROS - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); - ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - _mm256_storeu_pd((double *)(tC + 2), ymm11); - } - } - - m_remainder = M - row_idx; - if ((m_remainder == 3)) - { - m_remainder -= 3; - __m128d xmm0; - - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = 3; - { - dcomplex* tA_temp = tA; - - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - ymm2 = _mm256_loadu_pd((double const *) - (tA_temp + 1 * lda)); - ymm3 = _mm256_loadu_pd((double const *) - (tA_temp + 2 * lda)); - - ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); - - _mm256_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - ymm6); - xmm0 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *) - (tA_packed + 0 * lda_packed + 2), - xmm0); - - _mm256_storeu_pd((double *) - (tA_packed + 1 * lda_packed), - ymm7); - xmm0 = _mm256_extractf128_pd(ymm3, 1); - _mm_storeu_pd((double *) - (tA_packed + 1 * lda_packed + 2), - xmm0); - - tA_temp += 2; - tA_packed += 2 * lda_packed; - } - - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - tA_packed[1].real = tA_temp[1 * lda].real; - tA_packed[1].imag = tA_temp[1 * lda].imag; - tA_packed[2].real = tA_temp[2 * lda].real; - tA_packed[2].imag = tA_temp[2 * lda].imag; - - tA_temp += 1; - tA_packed += lda_packed; - } - } - - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = 3; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - ymm13 = _mm256_addsub_pd(ymm13, ymm15); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm13, ymm0); - ymm14 = _mm256_mul_pd(ymm13, ymm14); - ymm13 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc * 2 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); - ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - ymm21 = _mm256_permute_pd(ymm21, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - ymm20 = _mm256_addsub_pd(ymm20, ymm21); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - ymm10 = _mm256_add_pd(ymm10, ymm18); - ymm13 = _mm256_add_pd(ymm13, ymm20); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - xmm0 = _mm256_extractf128_pd(ymm13, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col - * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd((tptr + - tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm12 = _mm256_addsub_pd(ymm12, ymm7); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm12, ymm0); - ymm14 = _mm256_mul_pd(ymm12, ymm14); - ymm12 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); - ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm17 = _mm256_permute_pd(ymm17, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm16 = _mm256_addsub_pd(ymm16, ymm17); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm12 = _mm256_add_pd(ymm12, ymm16); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - xmm0 = _mm256_extractf128_pd(ymm12, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - ymm1 = _mm256_mul_pd(ymm1, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - xmm0 = _mm_loadu_pd((double const *) - (tA + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm11 = _mm256_addsub_pd(ymm11, ymm5); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm11, ymm0); - ymm14 = _mm256_mul_pd(ymm11, ymm14); - ymm11 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - xmm0 = _mm_loadu_pd((double const *)(tC + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); - ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm7 = _mm256_permute_pd(ymm7, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm6 = _mm256_addsub_pd(ymm6, ymm7); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm11 = _mm256_add_pd(ymm11, ymm6); - - _mm256_storeu_pd((double *)tC, ymm8); - xmm0 = _mm256_extractf128_pd(ymm11, 0); - _mm_storeu_pd((double *)(tC + 2), xmm0); - } - } - if ((m_remainder == 2)) - { - m_remainder -= 2; + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = 2; + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B + // matrix i data and multiplies it with + // the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *) + (tA_packed + 2), ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) * + 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } - { - dcomplex* tA_temp = tA; + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - ymm2 = _mm256_loadu_pd((double const *) - (tA_temp + 1 * lda)); + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + _mm256_storeu_pd( + (double *)tA_packed, ymm0); + _mm256_storeu_pd( + (double *)(tA_packed + 2) + , ymm1); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + tA_packed += Z_MR; + } + } - ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); - _mm256_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - ymm6); - _mm256_storeu_pd((double *) - (tA_packed + 1 * lda_packed), - ymm7); + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); - tA_temp += 2; - tA_packed += 2 * lda_packed; - } + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + BLIS_SET_YMM_REG_ZEROS - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - tA_packed[1].real = tA_temp[1 * lda].real; - tA_packed[1].imag = tA_temp[1 * lda].imag; + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - tA_temp += 1; - tA_packed += lda_packed; - } - } + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = 2; + // col 2 + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + tC += ldc; - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - BLIS_SET_YMM_REG_ZEROS - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - ymm0 = _mm256_loadu_pd((double const *) - (tC + ldc * 2)); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - _mm256_storeu_pd((double *)tC, ymm8); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm9); - - tC += ldc; - - _mm256_storeu_pd((double *)tC, ymm10); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - BLIS_SET_YMM_REG_ZEROS - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - _mm256_storeu_pd((double *)tC, ymm8); - tC += ldc; - _mm256_storeu_pd((double *)tC, ymm9); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matri - // x data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matri - // x data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - ymm0 = _mm256_loadu_pd((double const *)tA); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - ymm0 = _mm256_loadu_pd((double const *)tC); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - _mm256_storeu_pd((double *)tC, ymm8); - } - } - if ((m_remainder == 1)) - { - m_remainder -= 1; - __m128d xmm0; - - tA = A + row_idx * lda; - tA_packed = D_A_pack; - lda_packed = 1; - - { - dcomplex* tA_temp = tA; - - for(k = 0; (k+1) < K; k += 2) - { - ymm0 = _mm256_loadu_pd((double const *) - (tA_temp + 0 * lda)); - - xmm0 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *) - (tA_packed + 0 * lda_packed), - xmm0); - - xmm0 = _mm256_extractf128_pd(ymm0, 1); - _mm_storeu_pd((double *)(tA_packed + 1 - * lda_packed), xmm0); - - tA_temp += 2; - tA_packed += 2 * lda_packed; - } - - for(; k < K; k += 1) - { - tA_packed[0].real = tA_temp[0 * lda].real; - tA_packed[0].imag = tA_temp[0 * lda].imag; - - tA_temp += 1; - tA_packed += lda_packed; - } - } - - tA_packed = D_A_pack; - row_idx_packed = 0; - lda_packed = 1; - - for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - // This loop is processing D_MR x K - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + (tb_inc_col*2) - * 2 + 1)); - - ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); - ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); - - tptr += (tb_inc_row * 2); - tB += tb_inc_row; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - ymm14 = _mm256_permute_pd(ymm14, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - ymm10 = _mm256_addsub_pd(ymm10, ymm14); - // alpha, beta multiplication. - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm10, ymm14); - ymm10 = _mm256_hsub_pd(ymm15, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - - xmm0 = _mm_loadu_pd((double const *) - (tC + ldc * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); - ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); - - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - ymm19 = _mm256_permute_pd(ymm19, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - ymm18 = _mm256_addsub_pd(ymm18, ymm19); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - ymm10 = _mm256_add_pd(ymm10, ymm18); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm10, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - n_remainder = N - col_idx; - if (n_remainder == 2) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 2 - + 1)); - - ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); - ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - ymm6 = _mm256_permute_pd(ymm6, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - ymm9 = _mm256_addsub_pd(ymm9, ymm6); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm9, ymm0); - ymm14 = _mm256_mul_pd(ymm9, ymm14); - ymm9 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - - xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); - ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - ymm15 = _mm256_permute_pd(ymm15, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - ymm14 = _mm256_addsub_pd(ymm14, ymm15); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - ymm9 = _mm256_add_pd(ymm9, ymm14); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - tC += ldc; - xmm0 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)tC, xmm0); - } - if (n_remainder == 1) - { - //pointer math to point to proper memory - tC = C + ldc * col_idx + row_idx; - tB = B + tb_inc_col * col_idx; - tA = tA_packed + row_idx_packed; - - // clear scratch registers. - - BLIS_SET_ALL_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - double *tptr = (double *)tB; - if(conjtransa && conjtransb) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransa) - { - ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - ymm0 = _mm256_mul_pd(ymm0, ymm20); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else if(conjtransb) - { - ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matrix - // data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - ymm3 = _mm256_mul_pd(ymm3, ymm21); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - else - { - for (k = 0; k < K; ++k) - { - // The inner loop broadcasts the B matri - // x data and - // multiplies it with the A matrix. - ymm2 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0)); - ymm3 = _mm256_broadcast_sd( - (double const *) - (tptr + tb_inc_col * 0 - + 1)); - - //broadcasted matrix B elements are - //multiplied - //with matrix A columns. - xmm0 = _mm_loadu_pd((double const *)(tA)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); - ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); - - tptr += tb_inc_row*2; - tA += lda_packed; - } - } - ymm4 = _mm256_permute_pd(ymm4, 0x5); - - ymm8 = _mm256_addsub_pd(ymm8, ymm4); - - // alpha, beta multiplication. - ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); - ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - - ymm14 = _mm256_permute_pd(ymm0, 0x5); - ymm14 = _mm256_mul_pd(ymm14, ymm1); - ymm15 = _mm256_mul_pd(ymm8, ymm0); - ymm14 = _mm256_mul_pd(ymm8, ymm14); - ymm8 = _mm256_hsub_pd(ymm15, ymm14); - - - BLIS_SET_YMM_REG_ZEROS - xmm0 = _mm_setzero_pd(); - - ymm2 = _mm256_broadcast_sd((double const *) - &beta_cast->real); - ymm3 = _mm256_broadcast_sd((double const *) - &beta_cast->imag); - - if(is_beta_non_zero) - { - // multiply C by beta and accumulate col 1. - xmm0 = _mm_loadu_pd((double const *)(tC)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); - - ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); - ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); - } - ymm5 = _mm256_permute_pd(ymm5, 0x5); - - ymm4 = _mm256_addsub_pd(ymm4, ymm5); - - ymm8 = _mm256_add_pd(ymm8, ymm4); - - xmm0 = _mm256_extractf128_pd(ymm8, 0); - _mm_storeu_pd((double *)tC, xmm0); - - } - } - // Return the buffer to pool - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ -#ifdef BLIS_ENABLE_MEM_TRACING - printf( "bli_dgemm_small_At(): releasing mem pool block\n" ); -#endif - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - return BLIS_SUCCESS; - } - else - { - AOCL_DTL_TRACE_EXIT_ERR( - AOCL_DTL_LEVEL_INFO, - "Invalid dimesions for dgemm_small_At." - ); - return BLIS_NONCONFORMAL_DIMENSIONS; - } + // modify the pointer arithematic to use packed A matrix. + col_idx_start = NR; + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + } + // Process NR columns of C matrix at a time. + for (col_idx = col_idx_start; (col_idx + (NR - 1)) < N; + col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd( + (double const *)tA); + ymm1 = _mm256_loadu_pd( + (double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts + // the B matrix data and multiplies it + // with the A matrix. This loop is + // processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K The inner loop broadcasts the + // B matrix data and multiplies it with + // the A matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + + tptr += (tb_inc_row * 2); + tA += lda; + } + + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and multiplies it with the A + // matrix. This loop is processing + // Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += (tb_inc_row * 2); + tA += lda; + } + } + else //handles non-transpose case + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + m_remainder = M - row_idx; + + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *)(tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0 + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + xmm0 = _mm_loadu_pd((double const *)(tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing Z_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = A + row_idx; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda; + } + + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for small gemm." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } +}; + +err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ) +{ + AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO); + if (bli_cpuid_is_avx_supported() == FALSE) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + bool conjtransa = bli_obj_has_conj(a); + bool conjtransb = bli_obj_has_conj(b); + + gint_t M = bli_obj_length( c ); // number of rows of Matrix C + gint_t N = bli_obj_width( c ); // number of columns of Matrix C + gint_t K = bli_obj_width_after_trans( a ); // number of columns of OP(A) + + if (N<3) //Implemenation assumes that N is atleast 3. + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "N < 3, cannot be processed by small gemm" + ); + return BLIS_NOT_YET_IMPLEMENTED; + } + + if( M && N && K ) + { + guint_t lda = bli_obj_col_stride( a ); // column stride of matrix OP(A) + guint_t ldb = bli_obj_col_stride( b ); // column stride of matrix OP(B) + guint_t ldc = bli_obj_col_stride( c ); // column stride of matrix C + guint_t row_idx, col_idx, k; + dcomplex *A = bli_obj_buffer_at_off(a); //pointer to elements of Matrix A + dcomplex *B = bli_obj_buffer_at_off(b); //pointer to elements of Matrix B + dcomplex *C = bli_obj_buffer_at_off(c); //pointer to elements of Matrix C + + dcomplex *tA = A, *tB = B, *tC = C;//, *tA_pack; + dcomplex *tA_packed; // temprorary pointer to hold packed A memory pointer + guint_t row_idx_packed; //packed A memory row index + guint_t lda_packed; //lda of packed A + dim_t tb_inc_row = 1; // row stride of matrix B + dim_t tb_inc_col = ldb; // column stride of matrix B + + dcomplex *alpha_cast, *beta_cast; // alpha, beta multiples + alpha_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, alpha); + beta_cast = bli_obj_buffer_for_1x1(BLIS_DCOMPLEX, beta); + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s; + dcomplex *D_A_pack = NULL; + rntm_t rntm; + + if( bli_obj_has_trans( b ) ) + { + tb_inc_col = 1; // switch row and column strides + tb_inc_row = ldb; + } + + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm0, ymm1, ymm2, ymm3; + + gint_t n_remainder; // If the N is non multiple of 3.(N%3) + gint_t m_remainder; // If the M is non multiple of 16.(M%16) + + //checking whether beta value is zero. + //if true, we should perform C=alpha * A*B operation + //instead of C = beta * C + alpha * (A * B) + bool is_beta_non_zero = 0; + if(!bli_obj_equals(beta, &BLIS_ZERO)) + is_beta_non_zero = 1; + + /* + * This function was using global array to pack part of A input when + * needed. + * However, using this global array make the function non-reentrant. + * Instead of using a global array we should allocate buffer for each + * invocation. + * Since the buffer size is too big or stack and doing malloc every time + * will be too expensive, + * better approach is to get the buffer from the pre-allocated pool and + * return + * it the pool once we are doing. + * + * In order to get the buffer from pool, we need access to memory broker, + * currently this function is not invoked in such a way that it can + * receive + * the memory broker (via rntm). Following hack will get the global memory + * broker that can be use it to access the pool. + * + * Note there will be memory allocation at least on first innovation + * as there will not be any pool created for this size. + * Subsequent invocations will just reuse the buffer from the pool. + */ + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + // Get the current size of the buffer pool for A block packing. + // We will use the same size to avoid pool re-initliazaton + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool(bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + // + // This kernel assumes that "A" will be unpackged if N <= 3. + // Usually this range (N <= 3) is handled by SUP, however, + // if SUP is disabled or for any other condition if we do + // enter this kernel with N <= 3, we want to make sure that + // "A" remains unpacked. + // + // If this check is removed it will result in the crash as + // reported in CPUPL-587. + // + + if ((N < 3) || ((Z_MR * K) << 4) > buffer_size) + { + required_packing_A = 0; + return BLIS_NOT_YET_IMPLEMENTED; + } + + if (required_packing_A == 1) + { +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small_At: Requesting mem pool block of size %lu\n", + buffer_size); +#endif + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + } + + /* + * The computation loop runs for D_MRxN columns of C matrix, thus + * accessing the D_MRxK A matrix data and KxNR B matrix data. + * The computation is organized as inner loops of dimension D_MRxNR. + */ + // Process D_MR rows of C matrix at a time. + for (row_idx = 0; (row_idx + (Z_MR - 1)) < M; row_idx += Z_MR) + { + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = Z_MR; + + // Pack 16xk of matrix A into buffer + // continuous access for A and strided stores to B + for(inc_t x = 0; (x) < 2; x += 1) + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + + tA += 2 * lda; + tA_packed = D_A_pack + (x + 1)*2; + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = Z_MR; + + // Process NR columns of C matrix at a time. + for (col_idx = 0; (col_idx + (NR - 1)) < N; col_idx += NR) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + +#ifdef BLIS_ENABLE_PREFETCH + _mm_prefetch((char*)(tC + 0), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + ldc + 8), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc), _MM_HINT_T0); + _mm_prefetch((char*)(tC + 2 * ldc + 8), _MM_HINT_T0); +#endif + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + (&beta_cast->imag)); + + + + BLIS_SET_YMM_REG_ZEROS + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + // col 2 + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + // col 3 + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2))); + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + (ldc * 2) + 2)); + ymm20 = _mm256_fmadd_pd(ymm0, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm0, ymm3, ymm21); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + _mm256_storeu_pd((double *)(tC + 2), ymm13); + + } + n_remainder = N - col_idx; + + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const*)tA); + ymm1 = _mm256_loadu_pd((double const*) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc + 2)); + ymm16 = _mm256_fmadd_pd(ymm0, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm0, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)(tC + 0), ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + _mm256_storeu_pd((double *)(tC + 2), ymm12); + } + // if the N is not multiple of 3. + // handling edge case. + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + double *tptr = (double *)tB; + + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *)(tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((double const *)(tptr + tb_inc_col * 0 + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + tptr += tb_inc_row*2; + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm1 = _mm256_loadu_pd((double const *) + (tA + 2)); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + + BLIS_SET_YMM_REG_ZEROS + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + 2)); + ymm6 = _mm256_fmadd_pd(ymm0, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm0, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + _mm256_storeu_pd((double *)(tC + 2), ymm11); + } + } + + m_remainder = M - row_idx; + if ((m_remainder == 3)) + { + m_remainder -= 3; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 3; + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + ymm3 = _mm256_loadu_pd((double const *) + (tA_temp + 2 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + xmm0 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed + 2), + xmm0); + + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + xmm0 = _mm256_extractf128_pd(ymm3, 1); + _mm_storeu_pd((double *) + (tA_packed + 1 * lda_packed + 2), + xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + tA_packed[2].real = tA_temp[2 * lda].real; + tA_packed[2].imag = tA_temp[2 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 3; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + ymm13 = _mm256_addsub_pd(ymm13, ymm15); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm13, ymm0); + ymm14 = _mm256_mul_pd(ymm13, ymm14); + ymm13 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + ymm20 = _mm256_fmadd_pd(ymm1, ymm2, ymm20); + ymm21 = _mm256_fmadd_pd(ymm1, ymm3, ymm21); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + ymm21 = _mm256_permute_pd(ymm21, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + ymm20 = _mm256_addsub_pd(ymm20, ymm21); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + ymm10 = _mm256_add_pd(ymm10, ymm18); + ymm13 = _mm256_add_pd(ymm13, ymm20); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + xmm0 = _mm256_extractf128_pd(ymm13, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col + * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd((tptr + + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm12 = _mm256_addsub_pd(ymm12, ymm7); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm12, ymm0); + ymm14 = _mm256_mul_pd(ymm12, ymm14); + ymm12 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + ymm16 = _mm256_fmadd_pd(ymm1, ymm2, ymm16); + ymm17 = _mm256_fmadd_pd(ymm1, ymm3, ymm17); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm17 = _mm256_permute_pd(ymm17, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm16 = _mm256_addsub_pd(ymm16, ymm17); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm12 = _mm256_add_pd(ymm12, ymm16); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + xmm0 = _mm256_extractf128_pd(ymm12, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + ymm1 = _mm256_mul_pd(ymm1, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + xmm0 = _mm_loadu_pd((double const *) + (tA + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm11 = _mm256_addsub_pd(ymm11, ymm5); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm11, ymm0); + ymm14 = _mm256_mul_pd(ymm11, ymm14); + ymm11 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + xmm0 = _mm_loadu_pd((double const *)(tC + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + ymm6 = _mm256_fmadd_pd(ymm1, ymm2, ymm6); + ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm7 = _mm256_permute_pd(ymm7, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm6 = _mm256_addsub_pd(ymm6, ymm7); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm11 = _mm256_add_pd(ymm11, ymm6); + + _mm256_storeu_pd((double *)tC, ymm8); + xmm0 = _mm256_extractf128_pd(ymm11, 0); + _mm_storeu_pd((double *)(tC + 2), xmm0); + } + } + if ((m_remainder == 2)) + { + m_remainder -= 2; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 2; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + ymm2 = _mm256_loadu_pd((double const *) + (tA_temp + 1 * lda)); + + ymm6 = _mm256_permute2f128_pd(ymm0,ymm2,0x20); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm2,0x31); + + _mm256_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + ymm6); + _mm256_storeu_pd((double *) + (tA_packed + 1 * lda_packed), + ymm7); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + tA_packed[1].real = tA_temp[1 * lda].real; + tA_packed[1].imag = tA_temp[1 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 2; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + ymm0 = _mm256_loadu_pd((double const *) + (tC + ldc * 2)); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + _mm256_storeu_pd((double *)tC, ymm8); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm9); + + tC += ldc; + + _mm256_storeu_pd((double *)tC, ymm10); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + ymm0 = _mm256_loadu_pd((double const *)(tC + ldc)); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + _mm256_storeu_pd((double *)tC, ymm8); + tC += ldc; + _mm256_storeu_pd((double *)tC, ymm9); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + ymm0 = _mm256_loadu_pd((double const *)tA); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + ymm0 = _mm256_loadu_pd((double const *)tC); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + _mm256_storeu_pd((double *)tC, ymm8); + } + } + if ((m_remainder == 1)) + { + m_remainder -= 1; + __m128d xmm0; + + tA = A + row_idx * lda; + tA_packed = D_A_pack; + lda_packed = 1; + + { + dcomplex* tA_temp = tA; + + for(k = 0; (k+1) < K; k += 2) + { + ymm0 = _mm256_loadu_pd((double const *) + (tA_temp + 0 * lda)); + + xmm0 = _mm256_extractf128_pd(ymm0, 0); + _mm_storeu_pd((double *) + (tA_packed + 0 * lda_packed), + xmm0); + + xmm0 = _mm256_extractf128_pd(ymm0, 1); + _mm_storeu_pd((double *)(tA_packed + 1 + * lda_packed), xmm0); + + tA_temp += 2; + tA_packed += 2 * lda_packed; + } + + for(; k < K; k += 1) + { + tA_packed[0].real = tA_temp[0 * lda].real; + tA_packed[0].imag = tA_temp[0 * lda].imag; + + tA_temp += 1; + tA_packed += lda_packed; + } + } + + tA_packed = D_A_pack; + row_idx_packed = 0; + lda_packed = 1; + + for (col_idx = 0; (col_idx + 2) < N; col_idx += 3) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + // This loop is processing D_MR x K + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + (tb_inc_col*2) + * 2 + 1)); + + ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10); + ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14); + + tptr += (tb_inc_row * 2); + tB += tb_inc_row; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + ymm14 = _mm256_permute_pd(ymm14, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + ymm10 = _mm256_addsub_pd(ymm10, ymm14); + // alpha, beta multiplication. + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm10, ymm0); + ymm14 = _mm256_mul_pd(ymm10, ymm14); + ymm10 = _mm256_hsub_pd(ymm15, ymm14); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + + xmm0 = _mm_loadu_pd((double const *) + (tC + ldc * 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm18 = _mm256_fmadd_pd(ymm0, ymm2, ymm18); + ymm19 = _mm256_fmadd_pd(ymm0, ymm3, ymm19); + + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + ymm19 = _mm256_permute_pd(ymm19, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + ymm18 = _mm256_addsub_pd(ymm18, ymm19); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + ymm10 = _mm256_add_pd(ymm10, ymm18); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm10, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + n_remainder = N - col_idx; + if (n_remainder == 2) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 2 + + 1)); + + ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9); + ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + ymm6 = _mm256_permute_pd(ymm6, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + ymm9 = _mm256_addsub_pd(ymm9, ymm6); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm9, ymm0); + ymm14 = _mm256_mul_pd(ymm9, ymm14); + ymm9 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + + xmm0 = _mm_loadu_pd((double const *)(tC + ldc)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm14 = _mm256_fmadd_pd(ymm0, ymm2, ymm14); + ymm15 = _mm256_fmadd_pd(ymm0, ymm3, ymm15); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + ymm15 = _mm256_permute_pd(ymm15, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + ymm14 = _mm256_addsub_pd(ymm14, ymm15); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + ymm9 = _mm256_add_pd(ymm9, ymm14); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + tC += ldc; + xmm0 = _mm256_extractf128_pd(ymm9, 0); + _mm_storeu_pd((double *)tC, xmm0); + } + if (n_remainder == 1) + { + //pointer math to point to proper memory + tC = C + ldc * col_idx + row_idx; + tB = B + tb_inc_col * col_idx; + tA = tA_packed + row_idx_packed; + + // clear scratch registers. + + BLIS_SET_ALL_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + double *tptr = (double *)tB; + if(conjtransa && conjtransb) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransa) + { + ymm20 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + ymm0 = _mm256_mul_pd(ymm0, ymm20); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else if(conjtransb) + { + ymm21 = _mm256_setr_pd(-1.0, -1.0, -1.0, -1.0); + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matrix + // data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + ymm3 = _mm256_mul_pd(ymm3, ymm21); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + else + { + for (k = 0; k < K; ++k) + { + // The inner loop broadcasts the B matri + // x data and + // multiplies it with the A matrix. + ymm2 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0)); + ymm3 = _mm256_broadcast_sd( + (double const *) + (tptr + tb_inc_col * 0 + + 1)); + + //broadcasted matrix B elements are + //multiplied + //with matrix A columns. + xmm0 = _mm_loadu_pd((double const *)(tA)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8); + ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4); + + tptr += tb_inc_row*2; + tA += lda_packed; + } + } + ymm4 = _mm256_permute_pd(ymm4, 0x5); + + ymm8 = _mm256_addsub_pd(ymm8, ymm4); + + // alpha, beta multiplication. + ymm0 = _mm256_broadcast_pd(( __m128d const*)alpha_cast); + ymm1 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + + ymm14 = _mm256_permute_pd(ymm0, 0x5); + ymm14 = _mm256_mul_pd(ymm14, ymm1); + ymm15 = _mm256_mul_pd(ymm8, ymm0); + ymm14 = _mm256_mul_pd(ymm8, ymm14); + ymm8 = _mm256_hsub_pd(ymm15, ymm14); + + + BLIS_SET_YMM_REG_ZEROS + xmm0 = _mm_setzero_pd(); + + ymm2 = _mm256_broadcast_sd((double const *) + &beta_cast->real); + ymm3 = _mm256_broadcast_sd((double const *) + &beta_cast->imag); + + if(is_beta_non_zero) + { + // multiply C by beta and accumulate col 1. + xmm0 = _mm_loadu_pd((double const *)(tC)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm0, 0); + + ymm4 = _mm256_fmadd_pd(ymm0, ymm2, ymm4); + ymm5 = _mm256_fmadd_pd(ymm0, ymm3, ymm5); + } + ymm5 = _mm256_permute_pd(ymm5, 0x5); + + ymm4 = _mm256_addsub_pd(ymm4, ymm5); + + ymm8 = _mm256_add_pd(ymm8, ymm4); + + xmm0 = _mm256_extractf128_pd(ymm8, 0); + _mm_storeu_pd((double *)tC, xmm0); + + } + } + // Return the buffer to pool + if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )){ +#ifdef BLIS_ENABLE_MEM_TRACING + printf( "bli_zgemm_small_At(): releasing mem pool block\n" ); +#endif + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + return BLIS_SUCCESS; + } + else + { + AOCL_DTL_TRACE_EXIT_ERR( + AOCL_DTL_LEVEL_INFO, + "Invalid dimesions for dgemm_small_At." + ); + return BLIS_NONCONFORMAL_DIMENSIONS; + } }; #endif diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 10a656835f..92ee71b2be 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -259,6 +259,28 @@ err_t bli_dgemm_small_At cntl_t* cntl ); +err_t bli_zgemm_small + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + +err_t bli_zgemm_small_At + ( + obj_t* alpha, + obj_t* a, + obj_t* b, + obj_t* beta, + obj_t* c, + cntx_t* cntx, + cntl_t* cntl + ); + // gemm square matrix size friendly implementation err_t bli_gemm_sqp ( From a73ebf1ba97b4c6431655bb6ada195360c4917d8 Mon Sep 17 00:00:00 2001 From: Sireesha Sanga Date: Wed, 6 Apr 2022 00:53:27 +0530 Subject: [PATCH 39/63] Performance Improvement for ztrsm small sizes Details: - Optimization of ztrsm for Non-unit Diag Variants. - Handled Overflow and Underflow Vulnerabilites in ztrsm small implementations. - Fixed failures observed in libflame testing. - Fine-tuned ztrsm small implementations for specific sizes 64<= m,n <= 256, by keeping the number of threads to the optimum value, under AOCL_DYNAMIC flag. - For small sizes, ztrsm small implementation is used for all variants. AMD-Internal: [SWLCSG-1194] Change-Id: I066491bb03e5cda390cb699182af4350ae60be2d --- frame/base/bli_rntm.c | 10 ++- frame/compat/bla_trsm_amd.c | 2 - kernels/zen/3/bli_trsm_small.c | 155 ++++++++++++--------------------- 3 files changed, 66 insertions(+), 101 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index c597074f58..f8e00c6208 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -639,6 +639,14 @@ void bli_nthreads_optimum( if(m<=512 && n<=512) n_threads_ideal = 4; } + else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + + if((m>=64) && (m<=256) && (n>=64) && (n<=256)) + n_threads_ideal = 8; + } else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) { dim_t n = bli_obj_length(c); diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index eb5c835ff5..9ff8073be0 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -1184,7 +1184,6 @@ void ztrsm_ * is doing better than native multithread */ bool nt = bli_thread_get_is_parallel(); - if((blis_side == BLIS_RIGHT) || (blis_diaga == BLIS_UNIT_DIAG)) { if(((nt==0) && (m0<=500) && (n0<=500)) || (nt && ((m0+n0)<128))) { @@ -1206,7 +1205,6 @@ void ztrsm_ return; } } - } #endif bli_trsmnat diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 32b7647a50..bb6d198c78 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -5771,68 +5771,58 @@ BLIS_INLINE err_t ztrsm_AuXB_ref * Performs dcomplex division of vec1 and vec2 with ymm1. * vec1 and vec2 gets divided by ymm1 which holds * diagonal element from buffer. - * Function gets called while performing TRSM. + * Using bli_zinvscals() to avoid overflow and underflow + * scenarios. Function gets called while performing TRSM. */ #define BLIS_ZTRSM_TWO_DIV(vec1, vec2) {\ if(!is_unitdiag) {\ if(conjtransa){\ ymm1 = _mm256_mul_pd(ymm1, ymm0);\ }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0);\ - /*perform decomplex multiplication*/\ - /* Switch the real and imaginary elements of vec2 */\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - /* Multiply vec1 and vec2 */ \ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - /* Multiply vec1 and the modified vec2 */\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - /* Horizontally subtract the elements in vec3 and vec4 */\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - /* Negate the imaginary elements of vec2 */\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec2, ymm12);\ - ymm14 = _mm256_mul_pd(vec2, ymm14);\ - vec2 = _mm256_hsub_pd(ymm13, ymm14);\ - /*dcomplex multiplication is done*/\ - /*Swapping real & imaginary component position for addition with respective - * components*/\ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ - vec2 = _mm256_div_pd(vec2, ymm14);\ +\ + dcomplex b_data[4];\ + dcomplex d11_data[2];\ +\ + _mm256_storeu_pd((double *)(b_data), vec1);\ + _mm256_storeu_pd((double *)(b_data + 2), vec2);\ + _mm256_storeu_pd((double *)(d11_data), ymm1);\ +\ + for(dim_t i = 0; i < 4; i++)\ + {\ + bli_zinvscals(d11_data[0],b_data[i]);\ + }\ +\ + vec1 = _mm256_loadu_pd((double *)b_data);\ + vec2 = _mm256_loadu_pd((double *)(b_data+2));\ +\ }\ } /** * Performs dcomplex division of vec1 with ymm1. * ymm1 holds diagonal element from buffer. - * Function gets called while performing TRSM. + * Using bli_zinvscals() to avoid overflow and underflow + * scenarios. Function gets called while performing TRSM. */ #define BLIS_ZTRSM_DIV(vec1) {\ if(!is_unitdiag){\ if(conjtransa){\ ymm1 = _mm256_mul_pd(ymm1, ymm0);\ }\ - ymm12 = _mm256_mul_pd(ymm1, ymm0); /*vec2 and ymm8 is vec1*/\ - ymm14 = _mm256_permute_pd(ymm12, 0x5);\ - ymm14 = _mm256_mul_pd(ymm14, ymm0);\ - ymm13 = _mm256_mul_pd(vec1, ymm12); /*vec3*/\ - ymm14 = _mm256_mul_pd(vec1, ymm14); /*vec4*/\ - vec1 = _mm256_hsub_pd(ymm13, ymm14);\ - \ - ymm12 = _mm256_mul_pd(ymm1, ymm1);\ - ymm13 = _mm256_permute4x64_pd(ymm12, 0xb1);\ - ymm14 = _mm256_add_pd(ymm12, ymm13);\ - \ - /*Finally dividing numerator by denominator*/\ - vec1 = _mm256_div_pd(vec1, ymm14);\ +\ + dcomplex b_data[2];\ + dcomplex d11_data[2];\ +\ + _mm256_storeu_pd((double *)(b_data), vec1);\ + _mm256_storeu_pd((double *)(d11_data), ymm1);\ +\ + for(dim_t i = 0; i < 2; i++)\ + {\ + bli_zinvscals(d11_data[0],b_data[i]);\ + }\ +\ + vec1 = _mm256_loadu_pd((double *)b_data);\ +\ }\ } @@ -6007,7 +5997,6 @@ BLIS_INLINE void bli_ztrsm_small_pack } - BLIS_INLINE void ztrsm_small_pack_diag_element ( bool is_unitdiag, @@ -6018,64 +6007,31 @@ BLIS_INLINE void ztrsm_small_pack_diag_element ) { #ifdef BLIS_ENABLE_TRSM_PREINVERSION - __m256d ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8; - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); -#else - __m256d ymm1, ymm2, ymm3; -#endif - bool is_four = (size == 4) ? 1 : 0; - dcomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - if(!is_unitdiag) + // If Preinversion is enabled, inverse the diaganol + // elements from A and pack into diagonal buffer. + // In order to avoid the overflow and underflow scenarios, + // bli_zinvscals is used + for( dim_t i = 0; i < size; i++) { - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_pd((__m128d const *)a11); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a +1); - /*Pick one element frome each column and create 3 element vector - and store it*/ - ymm1 = _mm256_permute2f128_pd(ymm1, ymm2, 0x20); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - - if(is_four) - { - ymm3 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*2 + 2); - ymm2 = _mm256_broadcast_pd((__m128d const *)a11+ cs_a*3 + 3); - ymm2 = _mm256_permute2f128_pd(ymm3, ymm2, 0x20); - } + dim_t d = ((i*cs_a) + i); + dcomplex ones = {1.0, 0.0}; + bli_zinvscals(a11[d], ones) + d11_pack[i].real = ones.real; + d11_pack[i].imag = ones.imag; + } -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - /*Taking denomerator multiplication of real & imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - ymm5 = _mm256_mul_pd(ymm2,ymm2); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - ymm8 = _mm256_permute4x64_pd(ymm5, 0xb1); - - ymm5 = _mm256_add_pd(ymm5, ymm8); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - ymm2 = _mm256_mul_pd(ymm2, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); - ymm2 = _mm256_div_pd(ymm2, ymm5); -#endif +#else //BLIS_ENABLE_TRSM_PREINVERSION - } - _mm256_store_pd((double *)d11_pack, ymm1); - if(is_four) + // If Preinversion is disabled, pack the diaganol + // elements from A into diagonal buffer. + for( dim_t i = 0; i < size; i++) { - _mm256_store_pd((double *)(d11_pack + 2), ymm2); + dim_t d = ((i*cs_a) + i); + bli_zcopys(a11[d],d11_pack[i]); } - else - { - _mm_store_pd((double *)(d11_pack + 2), - _mm256_extractf128_pd(ymm2,0)); - } +#endif //BLIS_ENABLE_TRSM_PREINVERSION } - /*implements TRSM for the case XA = alpha * B *A is lower triangular, non-unit diagonal/unit diagonal, transpose *dimensions: X:mxn A:nxn B: mxn @@ -14948,9 +14904,12 @@ BLIS_INLINE void strsm_small_pack_diag_element __m256 ymm0, ymm1, ymm2, ymm3; __m256 ymm4, ymm5, ymm6, ymm7; __m256 ymm8, ymm9, ymm10,ymm11; - __m256 ymm14, ymm15, ymm12,ymm13; + __m256 ymm14, ymm15, ymm12; float ones = 1.0; - ymm13 = ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); + ymm14 = ymm15 = _mm256_broadcast_ss((float const *)&ones); +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + __m256 ymm13 = _mm256_broadcast_ss((float const *)&ones); +#endif if(side=='L'||side=='l') { if(!is_unitdiag) From f8f6cc6d81ba9b00c5b5e5c59d5401d592e8c3b5 Mon Sep 17 00:00:00 2001 From: Arnav Sharma Date: Mon, 21 Mar 2022 12:53:05 +0530 Subject: [PATCH 40/63] Optimized S/DCOMPLEX DOTXAXPYF using AVX2 Intrinsics Details: - Optimized implementation of DOTXAXPYF fused kernel for single and double precision complex datatype using AVX2 Intrinsics - Updated definitions zen context AMD-Internal: [CPUPL-2059] Change-Id: Ic657e4b66172ae459173626222af2756a4125565 --- config/zen/bli_cntx_init_zen.c | 5 +- config/zen2/bli_cntx_init_zen2.c | 5 +- config/zen3/bli_cntx_init_zen3.c | 5 +- kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c | 832 ++++++++++++++++++++++- kernels/zen/bli_kernels_zen.h | 2 + 5 files changed, 843 insertions(+), 6 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 1badc24f96..674549d77f 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -80,12 +80,15 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 10, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_8, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 997ccdba2e..48cb90a4f8 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -92,12 +92,15 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 10, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index 61fefdbc31..e83a12b401 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -92,12 +92,15 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1f kernels. bli_cntx_set_l1f_kers ( - 10, + 12, // axpyf BLIS_AXPYF_KER, BLIS_FLOAT, bli_saxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DOUBLE, bli_daxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_SCOMPLEX, bli_caxpyf_zen_int_5, BLIS_AXPYF_KER, BLIS_DCOMPLEX, bli_zaxpyf_zen_int_5, + // dotxaxpyf + BLIS_DOTXAXPYF_KER, BLIS_SCOMPLEX, bli_cdotxaxpyf_zen_int_8, + BLIS_DOTXAXPYF_KER, BLIS_DCOMPLEX, bli_zdotxaxpyf_zen_int_8, // dotxf BLIS_DOTXF_KER, BLIS_FLOAT, bli_sdotxf_zen_int_8, BLIS_DOTXF_KER, BLIS_DOUBLE, bli_ddotxf_zen_int_8, diff --git a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c index b24aab7571..1be9975ecf 100644 --- a/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c +++ b/kernels/zen/1f/bli_dotxaxpyf_zen_int_8.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -40,6 +40,12 @@ typedef union{ double d[4] __attribute__((aligned(64))); }vec; +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + /** * bli_pre_hemv_lower_8x8 is a helper function which computes * "y = y + alpha * a * x" @@ -467,8 +473,9 @@ void bli_ddotxaxpyf_zen_int_8 /* A is m x n. */ /* y = beta * y + alpha * A^T w; */ /* z = z + alpha * A x; */ - if ((inca == 1) && (incw == 1) && (incx == 1) - && (incy == 1) && (incz == 1) && (b_n == 8)) + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + (inca == 1) && (incw == 1) && (incx == 1) + && (incy == 1) && (incz == 1) && (b_n == 8) ) { __m256d r0, r1; r0 = _mm256_setzero_pd(); @@ -733,3 +740,822 @@ void bli_ddotxaxpyf_zen_int_8 ); } } + +/** + * zdotxaxpyf kernel performs dot and apxy function together. + * y := conj(beta) * y + conj(alpha) * conj(A)^t * conj(w) (dotxf) + * z := z + alpha * conj(A) * conj(x) (axpyf) + * where, + * A is an m x b matrix. + * w, z are vectors of length m. + * x, y are vectors of length b. + * alpha, beta are scalars + */ +void bli_zdotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + dcomplex* restrict alpha, + dcomplex* restrict a, inc_t inca, inc_t lda, + dcomplex* restrict w, inc_t incw, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict beta, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + // A: m x b + // w, z: m + // x, y: b + // + // y = beta * y + alpha * A^T w; + // z = z + alpha * A x; + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + ( inca == 1 ) && ( incw == 1 ) && ( incx == 1 ) + && ( incy == 1 ) && ( incz == 1 ) && ( b_n == 4 ) ) + { + // Temporary rho buffer holds computed dot product result + dcomplex rho[ 4 ]; + + // chi? variables to hold scaled scaler values from x vector + dcomplex chi0; + dcomplex chi1; + dcomplex chi2; + dcomplex chi3; + + // If beta is zero, clear y + // Else, scale by beta + if ( PASTEMAC(z,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,scals)( *beta, y[i] ); + } + } + + // If the vectors are empty or if alpha is zero, return early + if ( bli_zero_dim1( m ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Initialize rho vector to 0 + for ( dim_t i = 0; i < 4; ++i ) PASTEMAC(z,set0s)( rho[i] ); + + // Set conj use variable for dot operation + conj_t conjdot_use = conjw; + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjdot_use ); + } + + // Set conj use variable for dotxf operation, scalar + dim_t conjdotxf = 1; + if ( bli_is_conj( conjdot_use ) ) + { + conjdotxf = -1; + } + + // Set conj use variable for axpyf operation, scalar + dim_t conjaxpyf = 1; + if ( bli_is_conj( conja ) ) + { + conjaxpyf = -1; + } + + // Store each element of x vector in a scalar and apply conjx + if( bli_is_noconj( conjx ) ) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + bli_zcopycjs( conjx, *( x + 0*incx ), chi0 ); + bli_zcopycjs( conjx, *( x + 1*incx ), chi1 ); + bli_zcopycjs( conjx, *( x + 2*incx ), chi2 ); + bli_zcopycjs( conjx, *( x + 3*incx ), chi3 ); + } + + // Scale each chi scalar by alpha + bli_zscals( *alpha, chi0 ); + bli_zscals( *alpha, chi1 ); + bli_zscals( *alpha, chi2 ); + bli_zscals( *alpha, chi3 ); + + dim_t row = 0; + dim_t iter = m / 2; + dim_t rem = m % 2; + if (iter) + { + vec x0R, x1R, x2R, x3R; // x?R holds real part of x[?] + vec x0I, x1I, x2I, x3I; // x?I hold real part of x[?] + vec a_tile0, a_tile1; // a_tile? holds columns of a + vec temp1, temp2, temp3; // temp? registers for intermediate op + vec wR, wI; // holds real & imag components of w + vec z_vec; // holds the z vector + + // rho? registers hold results of fmadds for dotxf operation + vec rho0, rho1, rho2, rho3; + vec rho4, rho5, rho6, rho7; + + // For final computation, based on conjdot_use + // sign of imaginary component needs to be toggled + __m256d no_conju = _mm256_setr_pd( -1, 1, -1, 1 ); + __m256d conju = _mm256_setr_pd( 1, -1, 1, -1 ); + + // Clear the temp registers + temp1.v = _mm256_setzero_pd(); + temp2.v = _mm256_setzero_pd(); + temp3.v = _mm256_setzero_pd(); + + // Clear rho registers + // Once micro tile is computed, horizontal addition + // of all rho's will provide us with the result of + // dotxf opereation + rho0.v = _mm256_setzero_pd(); + rho1.v = _mm256_setzero_pd(); + rho2.v = _mm256_setzero_pd(); + rho3.v = _mm256_setzero_pd(); + rho4.v = _mm256_setzero_pd(); + rho5.v = _mm256_setzero_pd(); + rho6.v = _mm256_setzero_pd(); + rho7.v = _mm256_setzero_pd(); + + // Broadcast real & imag parts of 4 elements of x + // to perform axpyf operation with 4x8 tile of A + x0R.v = _mm256_broadcast_sd( &chi0.real ); // real part of x0 + x0I.v = _mm256_broadcast_sd( &chi0.imag ); // imag part of x0 + x1R.v = _mm256_broadcast_sd( &chi1.real ); // real part of x1 + x1I.v = _mm256_broadcast_sd( &chi1.imag ); // imag part of x1 + x2R.v = _mm256_broadcast_sd( &chi2.real ); // real part of x2 + x2I.v = _mm256_broadcast_sd( &chi2.imag ); // imag part of x2 + x3R.v = _mm256_broadcast_sd( &chi3.real ); // real part of x3 + x3I.v = _mm256_broadcast_sd( &chi3.imag ); // imag part of x3 + + for ( ; ( row + 1 ) < m; row += 2) + { + // Load first two columns of A + // a_tile0.v -> a00R a00I a10R a10I + // a_tile1.v -> a01R a01I a11R a11I + a_tile0.v = _mm256_loadu_pd( (double *)&a[row + 0 * lda] ); + a_tile1.v = _mm256_loadu_pd( (double *)&a[row + 1 * lda] ); + + temp1.v = _mm256_mul_pd( a_tile0.v, x0R.v ); + temp2.v = _mm256_mul_pd( a_tile0.v, x0I.v ); + + temp1.v = _mm256_fmadd_pd( a_tile1.v, x1R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile1.v, x1I.v, temp2.v ); + + // Load w vector + // wR.v -> w0R w0I w1R w1I + // wI.v ( shuf wR.v ) -> w0I w0I w1I w1I + // wR.v ( shuf wR.v ) -> w0R w0R w1R w1R + wR.v = _mm256_loadu_pd( (double *)&w[row] ); + wI.v = _mm256_permute_pd( wR.v, 15 ); + wR.v = _mm256_permute_pd( wR.v, 0 ); + + rho0.v = _mm256_fmadd_pd( a_tile0.v, wR.v, rho0.v); + rho4.v = _mm256_fmadd_pd( a_tile0.v, wI.v, rho4.v); + + rho1.v = _mm256_fmadd_pd( a_tile1.v, wR.v, rho1.v); + rho5.v = _mm256_fmadd_pd( a_tile1.v, wI.v, rho5.v); + + // Load 3rd and 4th columns of A + // a_tile0.v -> a20R a20I a30R a30I + // a_tile1.v -> a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_pd( (double *)&a[row + 2 * lda] ); + a_tile1.v = _mm256_loadu_pd( (double *)&a[row + 3 * lda] ); + + temp1.v = _mm256_fmadd_pd( a_tile0.v, x2R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile0.v, x2I.v, temp2.v ); + + temp1.v = _mm256_fmadd_pd( a_tile1.v, x3R.v, temp1.v ); + temp2.v = _mm256_fmadd_pd( a_tile1.v, x3I.v, temp2.v ); + + rho2.v = _mm256_fmadd_pd( a_tile0.v, wR.v, rho2.v); + rho6.v = _mm256_fmadd_pd( a_tile0.v, wI.v, rho6.v); + + rho3.v = _mm256_fmadd_pd( a_tile1.v, wR.v, rho3.v); + rho7.v = _mm256_fmadd_pd( a_tile1.v, wI.v, rho7.v); + + // Load z vector + z_vec.v = _mm256_loadu_pd( (double *)&z[row] ); + + // Permute the result and alternatively add-sub final values + if( bli_is_noconj( conja ) ) + { + temp2.v = _mm256_permute_pd(temp2.v, 5); + temp3.v = _mm256_addsub_pd(temp1.v, temp2.v); + } + else + { + temp1.v = _mm256_permute_pd( temp1.v, 5 ); + temp3.v = _mm256_addsub_pd( temp2.v, temp1.v ); + temp3.v = _mm256_permute_pd( temp3.v, 5 ); + } + + // Add & store result to z_vec + z_vec.v = _mm256_add_pd( temp3.v, z_vec.v ); + _mm256_storeu_pd( (double *)&z[row], z_vec.v ); + } + + // Swapping position of real and imag component + // for horizontal addition to get the final + // dot product computation + // rho register are holding computation which needs + // to be arranged in following manner. + // a0R * x0I | a0I * x0I | a1R * x1I | a1I * x1R + // || + // \/ + // a0I * x0I | a0R * x0I | a1I * x1R | a1R * x1I + + rho4.v = _mm256_permute_pd(rho4.v, 0x05); + rho5.v = _mm256_permute_pd(rho5.v, 0x05); + rho6.v = _mm256_permute_pd(rho6.v, 0x05); + rho7.v = _mm256_permute_pd(rho7.v, 0x05); + + // Negating imaginary part for computing + // the final result of dcomplex multiplication + if ( bli_is_noconj( conjdot_use ) ) + { + rho4.v = _mm256_mul_pd(rho4.v, no_conju); + rho5.v = _mm256_mul_pd(rho5.v, no_conju); + rho6.v = _mm256_mul_pd(rho6.v, no_conju); + rho7.v = _mm256_mul_pd(rho7.v, no_conju); + } + else + { + rho4.v = _mm256_mul_pd(rho4.v, conju); + rho5.v = _mm256_mul_pd(rho5.v, conju); + rho6.v = _mm256_mul_pd(rho6.v, conju); + rho7.v = _mm256_mul_pd(rho7.v, conju); + } + + rho0.v = _mm256_add_pd(rho0.v, rho4.v); + rho1.v = _mm256_add_pd(rho1.v, rho5.v); + rho2.v = _mm256_add_pd(rho2.v, rho6.v); + rho3.v = _mm256_add_pd(rho3.v, rho7.v); + + // rho0 & rho1 hold final dot product + // result of 4 dcomplex elements + rho0.d[0] += rho0.d[2]; + rho0.d[1] += rho0.d[3]; + + rho0.d[2] = rho1.d[0] + rho1.d[2]; + rho0.d[3] = rho1.d[1] + rho1.d[3]; + + rho1.d[0] = rho2.d[0] + rho2.d[2]; + rho1.d[1] = rho2.d[1] + rho2.d[3]; + + rho1.d[2] = rho3.d[0] + rho3.d[2]; + rho1.d[3] = rho3.d[1] + rho3.d[3]; + + // Storing the computed dot product + // in temp buffer rho for further computation. + _mm256_storeu_pd( (double *)rho, rho0.v ); + _mm256_storeu_pd( (double *)(rho+2) , rho1.v ); + } + + // To handle the remaining cases + if ( rem ) + { + PRAGMA_SIMD + for ( dim_t p = row; p < m; ++p ) + { + const dcomplex a0c = a[p + 0 * lda]; + const dcomplex a1c = a[p + 1 * lda]; + const dcomplex a2c = a[p + 2 * lda]; + const dcomplex a3c = a[p + 3 * lda]; + + // dot + dcomplex r0c = rho[0]; + dcomplex r1c = rho[1]; + dcomplex r2c = rho[2]; + dcomplex r3c = rho[3]; + + dcomplex w0c = w[p]; + + r0c.real += a0c.real * w0c.real - a0c.imag * w0c.imag + * conjdotxf; + r0c.imag += a0c.imag * w0c.real + a0c.real * w0c.imag + * conjdotxf; + r1c.real += a1c.real * w0c.real - a1c.imag * w0c.imag + * conjdotxf; + r1c.imag += a1c.imag * w0c.real + a1c.real * w0c.imag + * conjdotxf; + r2c.real += a2c.real * w0c.real - a2c.imag * w0c.imag + * conjdotxf; + r2c.imag += a2c.imag * w0c.real + a2c.real * w0c.imag + * conjdotxf; + r3c.real += a3c.real * w0c.real - a3c.imag * w0c.imag + * conjdotxf; + r3c.imag += a3c.imag * w0c.real + a3c.real * w0c.imag + * conjdotxf; + + rho[0] = r0c; + rho[1] = r1c; + rho[2] = r2c; + rho[3] = r3c; + + // axpy + dcomplex z0c = z[p]; + + z0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag + * conjaxpyf; + z0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag + * conjaxpyf; + z0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag + * conjaxpyf; + z0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag + * conjaxpyf; + z0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag + * conjaxpyf; + z0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag + * conjaxpyf; + z0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag + * conjaxpyf; + z0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag + * conjaxpyf; + + z[p] = z0c; + } + } + + // Conjugating the final result if conjat + if ( bli_is_conj( conjat ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,conjs)( rho[i] ); + } + } + + // Scaling the dot product result with alpha + // and adding the result to vector y + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(z,axpys)( *alpha, rho[i], y[i] ); + } + } + else + { + // For non-unit increments + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(z,type); + PASTECH(z,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(z,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} + +/** + * cdotxaxpyf kernel performs dot and apxy function together. + * y := conj(beta) * y + conj(alpha) * conj(A)^t * conj(w) (dotxf) + * z := z + alpha * conj(A) * conj(x) (axpyf) + * where, + * A is an m x b matrix. + * w, z are vectors of length m. + * x, y are vectors of length b. + * alpha, beta are scalars + */ +void bli_cdotxaxpyf_zen_int_8 +( + conj_t conjat, + conj_t conja, + conj_t conjw, + conj_t conjx, + dim_t m, + dim_t b_n, + scomplex* restrict alpha, + scomplex* restrict a, inc_t inca, inc_t lda, + scomplex* restrict w, inc_t incw, + scomplex* restrict x, inc_t incx, + scomplex* restrict beta, + scomplex* restrict y, inc_t incy, + scomplex* restrict z, inc_t incz, + cntx_t* restrict cntx + ) +{ + // A: m x b + // w, z: m + // x, y: b + // + // y = beta * y + alpha * A^T w; + // z = z + alpha * A x; + if ( ( bli_cpuid_is_avx_supported() == TRUE ) && + ( inca == 1 ) && ( incw == 1 ) && ( incx == 1 ) + && ( incy == 1 ) && ( incz == 1 ) && ( b_n == 4 ) ) + { + // Temporary rho buffer holds computed dot product result + scomplex rho[ 4 ]; + + // chi? variables to hold scaled scaler values from x vector + scomplex chi0; + scomplex chi1; + scomplex chi2; + scomplex chi3; + + // If beta is zero, clear y + // Else, scale by beta + if ( PASTEMAC(c,eq0)( *beta ) ) + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(c,set0s)( y[i] ); + } + } + else + { + for ( dim_t i = 0; i < 4; ++i ) + { + PASTEMAC(c,scals)( *beta, y[i] ); + } + } + + // If the vectors are empty or if alpha is zero, return early + if ( bli_zero_dim1( m ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Initialize rho vector to 0 + for ( dim_t i = 0; i < 4; ++i ) PASTEMAC(c,set0s)( rho[i] ); + + // Set conj use variable for dot operation + conj_t conjdot_use = conjw; + if ( bli_is_conj( conjat ) ) + { + bli_toggle_conj( &conjdot_use ); + } + + // Set conj use variable for dotxf operation, scalar + dim_t conjdotxf = 1; + if ( bli_is_conj( conjdot_use ) ) + { + conjdotxf = -1; + } + + // Set conj use variable for axpyf operation, scalar + dim_t conjaxpyf = 1; + if ( bli_is_conj( conja ) ) + { + conjaxpyf = -1; + } + + // Store each element of x vector in a scalar and apply conjx + if( bli_is_noconj( conjx ) ) + { + chi0 = *( x + 0*incx ); + chi1 = *( x + 1*incx ); + chi2 = *( x + 2*incx ); + chi3 = *( x + 3*incx ); + } + else + { + bli_ccopycjs( conjx, *( x + 0*incx ), chi0 ); + bli_ccopycjs( conjx, *( x + 1*incx ), chi1 ); + bli_ccopycjs( conjx, *( x + 2*incx ), chi2 ); + bli_ccopycjs( conjx, *( x + 3*incx ), chi3 ); + } + + // Scale each chi scalar by alpha + bli_cscals( *alpha, chi0 ); + bli_cscals( *alpha, chi1 ); + bli_cscals( *alpha, chi2 ); + bli_cscals( *alpha, chi3 ); + + dim_t i = 0; + dim_t iter = m / 4; + dim_t rem = m % 4; + if (iter) + { + v8sf_t x0R, x1R, x2R, x3R; // x?R holds real part of x[?] + v8sf_t x0I, x1I, x2I, x3I; // x?I hold real part of x[?] + v8sf_t a_tile0, a_tile1; // a_tile? holds columns of a + v8sf_t temp1, temp2, temp3; // temp? registers for intermediate op + v8sf_t wR, wI; // holds real & imag components of w + v8sf_t z_vec; // holds the z vector + + // For final computation, based on conjdot_use + // sign of imaginary component needs to be toggled + __m256 no_conju = _mm256_setr_ps( -1, 1, -1, 1, -1, 1, -1, 1 ); + __m256 conju = _mm256_setr_ps( 1, -1, 1, -1, 1, -1, 1, -1 ); + + // Clear the temp registers + temp1.v = _mm256_setzero_ps(); + temp2.v = _mm256_setzero_ps(); + temp3.v = _mm256_setzero_ps(); + + // Clear rho registers + // Once micro tile is computed, horizontal addition + // of all rho's will provide us with the result of + // dotxf opereation + __m256 rho0v; rho0v = _mm256_setzero_ps(); + __m256 rho1v; rho1v = _mm256_setzero_ps(); + __m256 rho2v; rho2v = _mm256_setzero_ps(); + __m256 rho3v; rho3v = _mm256_setzero_ps(); + + __m256 rho4v; rho4v = _mm256_setzero_ps(); + __m256 rho5v; rho5v = _mm256_setzero_ps(); + __m256 rho6v; rho6v = _mm256_setzero_ps(); + __m256 rho7v; rho7v = _mm256_setzero_ps(); + + // Broadcast real & imag parts of 4 elements of x + // to perform axpyf operation with 4x8 tile of A + x0R.v = _mm256_broadcast_ss( &chi0.real ); // real part of x0 + x0I.v = _mm256_broadcast_ss( &chi0.imag ); // imag part of x0 + x1R.v = _mm256_broadcast_ss( &chi1.real ); // real part of x1 + x1I.v = _mm256_broadcast_ss( &chi1.imag ); // imag part of x1 + x2R.v = _mm256_broadcast_ss( &chi2.real ); // real part of x2 + x2I.v = _mm256_broadcast_ss( &chi2.imag ); // imag part of x2 + x3R.v = _mm256_broadcast_ss( &chi3.real ); // real part of x3 + x3I.v = _mm256_broadcast_ss( &chi3.imag ); // imag part of x3 + + for ( ; ( i + 3 ) < m; i += 4) + { + // Load first two columns of A + // a_tile0.v -> a00R a00I a10R a10I a20R a20I a30R a30I + // a_tile1.v -> a01R a01I a11R a11I a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_ps( (float *)&a[i + 0 * lda] ); + a_tile1.v = _mm256_loadu_ps( (float *)&a[i + 1 * lda] ); + + temp1.v = _mm256_mul_ps( a_tile0.v, x0R.v ); + temp2.v = _mm256_mul_ps( a_tile0.v, x0I.v ); + + temp1.v = _mm256_fmadd_ps( a_tile1.v, x1R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile1.v, x1I.v, temp2.v ); + + // Load w vector + // wR.v -> w0R w0I w1R w1I w2R w2I w3R w3I + // wI.v ( shuf wR.v ) -> w0I w0I w1I w1I w2I w2I w3I w3I + // wR.v ( shuf wR.v ) -> w0R w0R w1R w1R w2R w2R w3R w3R + wR.v = _mm256_loadu_ps( (float *) (w + i) ); + wI.v = _mm256_permute_ps( wR.v, 0xf5 ); + wR.v = _mm256_permute_ps( wR.v, 0xa0); + + rho0v = _mm256_fmadd_ps( a_tile0.v, wR.v, rho0v ); + rho4v = _mm256_fmadd_ps( a_tile0.v, wI.v, rho4v ); + + rho1v = _mm256_fmadd_ps( a_tile1.v, wR.v, rho1v ); + rho5v = _mm256_fmadd_ps( a_tile1.v, wI.v, rho5v ); + + // Load 3rd and 4th columns of A + // a_tile0.v -> a20R a20I a30R a30I + // a_tile1.v -> a21R a21I a31R a31I + a_tile0.v = _mm256_loadu_ps( (float *)&a[i + 2 * lda] ); + a_tile1.v = _mm256_loadu_ps( (float *)&a[i + 3 * lda] ); + + temp1.v = _mm256_fmadd_ps( a_tile0.v, x2R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile0.v, x2I.v, temp2.v ); + + temp1.v = _mm256_fmadd_ps( a_tile1.v, x3R.v, temp1.v ); + temp2.v = _mm256_fmadd_ps( a_tile1.v, x3I.v, temp2.v ); + + rho2v = _mm256_fmadd_ps( a_tile0.v, wR.v, rho2v ); + rho6v = _mm256_fmadd_ps( a_tile0.v, wI.v, rho6v ); + + rho3v = _mm256_fmadd_ps( a_tile1.v, wR.v, rho3v ); + rho7v = _mm256_fmadd_ps( a_tile1.v, wI.v, rho7v ); + + // Load z vector + z_vec.v = _mm256_loadu_ps( (float *)&z[i] ); + + // Permute the result and alternatively add-sub final values + if( bli_is_noconj( conja ) ) + { + temp2.v = _mm256_permute_ps(temp2.v, 0xB1); + temp3.v = _mm256_addsub_ps(temp1.v, temp2.v); + } + else + { + temp1.v = _mm256_permute_ps( temp1.v, 0xB1 ); + temp3.v = _mm256_addsub_ps( temp2.v, temp1.v ); + temp3.v = _mm256_permute_ps( temp3.v, 0xB1 ); + } + + // Add & store result to z_vec + z_vec.v = _mm256_add_ps( temp3.v, z_vec.v ); + _mm256_storeu_ps( (float *)&z[i], z_vec.v ); + } + + // Swapping position of real and imag component + // for horizontal addition to get the final + // dot product computation + // rho register are holding computation which needs + // to be arranged in following manner. + // a0R * x0I | a0I * x0I | a1R * x1I | a1I * x1R | ... + // || + // \/ + // a0I * x0I | a0R * x0I | a1I * x1R | a1R * x1I | ... + + rho4v = _mm256_permute_ps(rho4v, 0xb1); + rho5v = _mm256_permute_ps(rho5v, 0xb1); + rho6v = _mm256_permute_ps(rho6v, 0xb1); + rho7v = _mm256_permute_ps(rho7v, 0xb1); + + // Negating imaginary part for computing + // the final result of dcomplex multiplication + if ( bli_is_noconj( conjdot_use ) ) + { + rho4v = _mm256_mul_ps(rho4v, no_conju); + rho5v = _mm256_mul_ps(rho5v, no_conju); + rho6v = _mm256_mul_ps(rho6v, no_conju); + rho7v = _mm256_mul_ps(rho7v, no_conju); + } + else + { + rho4v = _mm256_mul_ps(rho4v, conju); + rho5v = _mm256_mul_ps(rho5v, conju); + rho6v = _mm256_mul_ps(rho6v, conju); + rho7v = _mm256_mul_ps(rho7v, conju); + } + + rho0v = _mm256_add_ps(rho0v, rho4v); + rho1v = _mm256_add_ps(rho1v, rho5v); + rho2v = _mm256_add_ps(rho2v, rho6v); + rho3v = _mm256_add_ps(rho3v, rho7v); + + // Horizontal addition of rho elements for computing final dotxf + // and storing the results into rho buffer + scomplex *ptr = (scomplex *)&rho0v; + for(dim_t j = 0; j < 4; j++) + { + rho[0].real += ptr[j].real; + rho[0].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho1v; + for(dim_t j = 0; j < 4; j++) + { + rho[1].real += ptr[j].real; + rho[1].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho2v; + for(dim_t j = 0; j < 4; j++) + { + rho[2].real += ptr[j].real; + rho[2].imag += ptr[j].imag; + } + ptr = (scomplex *)&rho3v; + for(dim_t j = 0; j < 4; j++) + { + rho[3].real += ptr[j].real; + rho[3].imag += ptr[j].imag; + } + } + + // To handle the remaining cases + if ( rem ) + { + PRAGMA_SIMD + for ( dim_t p = i; p < m; ++p ) + { + const scomplex a0c = a[p + 0 * lda]; + const scomplex a1c = a[p + 1 * lda]; + const scomplex a2c = a[p + 2 * lda]; + const scomplex a3c = a[p + 3 * lda]; + + // dot + scomplex r0c = rho[0]; + scomplex r1c = rho[1]; + scomplex r2c = rho[2]; + scomplex r3c = rho[3]; + + scomplex w0c = w[p]; + + r0c.real += a0c.real * w0c.real - a0c.imag * w0c.imag + * conjdotxf; + r0c.imag += a0c.imag * w0c.real + a0c.real * w0c.imag + * conjdotxf; + r1c.real += a1c.real * w0c.real - a1c.imag * w0c.imag + * conjdotxf; + r1c.imag += a1c.imag * w0c.real + a1c.real * w0c.imag + * conjdotxf; + r2c.real += a2c.real * w0c.real - a2c.imag * w0c.imag + * conjdotxf; + r2c.imag += a2c.imag * w0c.real + a2c.real * w0c.imag + * conjdotxf; + r3c.real += a3c.real * w0c.real - a3c.imag * w0c.imag + * conjdotxf; + r3c.imag += a3c.imag * w0c.real + a3c.real * w0c.imag + * conjdotxf; + + rho[0] = r0c; + rho[1] = r1c; + rho[2] = r2c; + rho[3] = r3c; + + // axpy + scomplex z0c = z[p]; + + z0c.real += chi0.real * a0c.real - chi0.imag * a0c.imag + * conjaxpyf; + z0c.real += chi1.real * a1c.real - chi1.imag * a1c.imag + * conjaxpyf; + z0c.real += chi2.real * a2c.real - chi2.imag * a2c.imag + * conjaxpyf; + z0c.real += chi3.real * a3c.real - chi3.imag * a3c.imag + * conjaxpyf; + z0c.imag += chi0.imag * a0c.real + chi0.real * a0c.imag + * conjaxpyf; + z0c.imag += chi1.imag * a1c.real + chi1.real * a1c.imag + * conjaxpyf; + z0c.imag += chi2.imag * a2c.real + chi2.real * a2c.imag + * conjaxpyf; + z0c.imag += chi3.imag * a3c.real + chi3.real * a3c.imag + * conjaxpyf; + + z[p] = z0c; + } + } + + // Conjugating the final result if conjat + if ( bli_is_conj( conjat ) ) + { + for ( dim_t j = 0; j < 4; ++j ) + { + PASTEMAC(c,conjs)( rho[j] ); + } + } + + // Scaling the dot product result with alpha + // and adding the result to vector y + for ( dim_t j = 0; j < 4; ++j ) + { + PASTEMAC(c,axpys)( *alpha, rho[j], y[j] ); + } + } + else + { + // For non-unit increments + /* Query the context for the kernel function pointer. */ + const num_t dt = PASTEMAC(c,type); + PASTECH(c,dotxf_ker_ft) kfp_df = + bli_cntx_get_l1f_ker_dt( dt, BLIS_DOTXF_KER, cntx ); + PASTECH(c,axpyf_ker_ft) kfp_af = + bli_cntx_get_l1f_ker_dt( dt, BLIS_AXPYF_KER, cntx ); + + kfp_df + ( + conjat, + conjw, + m, + b_n, + alpha, + a, inca, lda, + w, incw, + beta, + y, incy, + cntx + ); + + kfp_af + ( + conja, + conjx, + m, + b_n, + alpha, + a, inca, lda, + x, incx, + z, incz, + cntx + ); + } +} \ No newline at end of file diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 92ee71b2be..77d34807af 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -126,6 +126,8 @@ DOTXF_KER_PROT( dcomplex, z, dotxf_zen_int_6 ) DOTXF_KER_PROT( scomplex, c, dotxf_zen_int_6 ) // dotxaxpyf (intrinsics) DOTXAXPYF_KER_PROT( double, d, dotxaxpyf_zen_int_8 ) +DOTXAXPYF_KER_PROT( scomplex, c, dotxaxpyf_zen_int_8 ) +DOTXAXPYF_KER_PROT( dcomplex, z, dotxaxpyf_zen_int_8 ) // -- level-2 ---------------------------------------------------------------- From 1e43434713e96c96b58a788590d8290641c66ade Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Thu, 7 Apr 2022 00:18:38 -0500 Subject: [PATCH 41/63] Performance Improvement for ctrsm small sizes Details: - Enable ctrsm small implementation - Handled Overflow and Underflow Vulnerabilites in ctrsm small implementations. - Fixed failures observed in libflame testing. - For small sizes, ctrsm small implementation is used for all variants. Change-Id: I17b862dcb794a5af0ec68f585992131fef57b179 --- frame/compat/bla_trsm_amd.c | 9 +- kernels/zen/3/bli_trsm_small.c | 341 +++++++-------------------------- 2 files changed, 73 insertions(+), 277 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 9ff8073be0..e1a2fffafd 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -902,6 +902,8 @@ void dtrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } + + void ztrsm_ ( const f77_char* side, @@ -1221,7 +1223,8 @@ void ztrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } -#if 0 + + void ctrsm_ ( const f77_char* side, @@ -1236,7 +1239,7 @@ void ctrsm_ ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'c', *side, *uploa,*transa, *diaga, *m, *n, (void*)alpha,*lda, *ldb); @@ -1537,7 +1540,5 @@ void ctrsm_ /* Finalize BLIS. */ bli_finalize_auto(); } -#endif -INSERT_GENTFUNC_BLAS_C( trsm, trsm ) #endif diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index bb6d198c78..07077010f2 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -36852,33 +36852,19 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB */ #define SCOMPLEX_INV(a, b) {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - /*Compute denominator eliminating imaginary component*/\ - float dnm = (b.real * b.real);\ - /*multiply two times with -1 for correct result as - * dcomplex number with positive imaginary part will - * invert the sign if not multiplied twice with -1*/\ - dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\ - /*Compute the final result by dividing real and imag part by dnm*/\ - a.real /= dnm;\ - a.imag /= dnm;\ + a.real = 1.0;\ + a.imag = 0.0;\ + bli_cinvscals(b, a);\ } #define SCOMPLEX_MUL(a, b, c) {\ - float real = a.real * b.real;\ - real += ((a.imag * b.imag) * -1.0);\ - float imag = (a.real * b.imag);\ - imag += (a.imag * b.real);\ - c.real = real;\ - c.imag = imag;\ + c.real = b.real;\ + c.imag = b.imag;\ + bli_cscals(a,c);\ } #define SCOMPLEX_DIV(a, b){\ - float dnm = b.real * b.real;\ - dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\ - a.real /= dnm;\ - a.imag /= dnm;\ + bli_cinvscals(b,a); \ } #ifdef BLIS_ENABLE_TRSM_PREINVERSION @@ -36904,13 +36890,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define CTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\ - if(!is_unitdiag)\ - {\ - a.real = b.real;\ - a.imag = (b.imag * -1.0);\ - SCOMPLEX_MUL(c, a, c)\ - SCOMPLEX_DIV(c, b)\ - }\ + if(!is_unitdiag)\ + {\ + bli_cinvscals(b, c);\ + }\ } #endif @@ -37306,72 +37289,30 @@ BLIS_INLINE void ctrsm_small_pack_diag_element dim_t size ) { - __m256 ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm8; - bool is_eight = (size == 8) ? 1 : 0; - scomplex ones = {1.0, 1.0}; - ymm2 = ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); #ifdef BLIS_ENABLE_TRSM_PREINVERSION - __m256 ymm7; - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); -#endif - - if(!is_unitdiag) + // If Preinversion is disabled, inverse the diaganol + // elements from A and pack into diagonal buffer. + // In order to avoid the overflow and underflow scenarios, + // bli_cinvscals is used. + for( dim_t i = 0; i < size; i++) { - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_ps((__m128 const *)a11); - ymm2 = _mm256_broadcast_ps((__m128 const *) (a11+ cs_a +1)); - ymm3 = _mm256_broadcast_ps((__m128 const *) (a11+ cs_a*2 +2)); - - ymm1 = _mm256_shuffle_ps(ymm1, ymm2, 0x44); - - if(is_eight) { - ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 4 + cs_a*4)); - ymm5 = _mm256_broadcast_ps((__m128 const *)(a11 + 5 + cs_a*5)); - ymm6 = _mm256_shuffle_ps(ymm4, ymm5, 0x44); - - ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 6 + cs_a*6)); - ymm5 = _mm256_broadcast_ps((__m128 const *)(a11 + 7 + cs_a*7)); - ymm8 = _mm256_shuffle_ps(ymm4, ymm5, 0x44); - - ymm2 = _mm256_blend_ps(ymm6, ymm8, 0xF0); - - ymm4 = _mm256_broadcast_ps((__m128 const *)(a11 + 3 + cs_a*3)); - ymm3 = _mm256_shuffle_ps(ymm3, ymm4, 0x44); - } - - ymm1 = _mm256_blend_ps(ymm1, ymm3, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - /*Taking denomerator multiplication of real & imaginary components*/ - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm5 = _mm256_mul_ps(ymm2, ymm2); - /*Swapping real & imaginary component position for addition with - * respective components*/ - //BEFORE - //a[0] a[1] a[2] a[3] - //AFTER - //a[1] a[0] a[3] a[2] - //MESS - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm8 = _mm256_permute_ps(ymm5, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm5 = _mm256_add_ps(ymm5, ymm8); - - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm2 = _mm256_mul_ps(ymm2, ymm7); - - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_ps(ymm1, ymm4); - ymm2 = _mm256_div_ps(ymm2, ymm5); - -#endif + dim_t d = ((i*cs_a) + i); + scomplex ones = {1.0, 0.0}; + bli_cinvscals(a11[d], ones) + d11_pack[i].real = ones.real; + d11_pack[i].imag = ones.imag; } - _mm256_store_ps((float *)d11_pack, ymm1); - if(is_eight) + +#else //BLIS_ENABLE_TRSM_PREINVERSION + // If Preinversion is disabled, pack the diaganol + // elements from A into diagonal buffer. + for( dim_t i = 0; i < size; i++) { - _mm256_store_ps((float *)(d11_pack + 4), ymm2); + dim_t d = ((i*cs_a) + i); + bli_ccopys(a11[d],d11_pack[i]); } + +#endif //BLIS_ENABLE_TRSM_PREINVERSION } /** @@ -37619,26 +37560,19 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ ymm1 = _mm256_mul_ps(ymm1, ymm2);\ }\ - /*Negating imaginary component of numerator*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*BLIS_CTRSM_MUL(vec1)*/\ - /*BLIS_CTRSM_MUL(vec2)*/\ - /*vec1 * ymm1*/\ - ymm3 = _mm256_shuffle_ps(ymm1, ymm1, 0x11);\ - ymm2 = _mm256_shuffle_ps(vec1, vec1, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec1, vec1,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec1 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*vec1 * ymm1*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*Taking denomerator multiplication of real & imaginary components*/\ - ymm3 = _mm256_mul_ps(ymm1, ymm1);\ - ymm2 = _mm256_permute_ps(ymm3, 0xB1);\ - ymm3 = _mm256_add_ps(ymm2, ymm3);\ - /*Dividing numerator by denominator*/\ - vec1 = _mm256_div_ps(vec1, ymm3);\ + scomplex b_data[4];\ + scomplex d11_data[4];\ + \ + _mm256_storeu_ps((float *)(b_data), vec1);\ + _mm256_storeu_ps((float *)(d11_data), ymm1);\ + \ + for(dim_t i = 0; i < 4; i++)\ + {\ + bli_cinvscals(d11_data[0],b_data[i]);\ + }\ + \ + vec1 = _mm256_loadu_ps((float *)b_data);\ + \ }\ } @@ -37649,32 +37583,21 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ ymm1 = _mm256_mul_ps(ymm1, ymm2);\ }\ - /*Negating imaginary component of numerator*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*BLIS_CTRSM_MUL(vec1)*/\ - /*BLIS_CTRSM_MUL(vec2)*/\ - /*vec1 * ymm1*/\ - ymm3 = _mm256_shuffle_ps(ymm1, ymm1, 0x11);\ - ymm2 = _mm256_shuffle_ps(vec1, vec1, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec1, vec1,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec1 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*vec1 * ymm1*/\ - ymm2 = _mm256_shuffle_ps(vec2, vec2, 0xA0);\ - ymm16 = _mm256_shuffle_ps(vec2, vec2,0xF5);\ - ymm16 = _mm256_mul_ps(ymm16, ymm3);\ - vec2 = _mm256_fmaddsub_ps(ymm2, ymm1, ymm16);\ - /*done*/\ - ymm2 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0);\ - ymm1 = _mm256_mul_ps(ymm1, ymm2);\ - /*Taking denomerator multiplication of real & imaginary components*/\ - ymm3 = _mm256_mul_ps(ymm1, ymm1);\ - ymm2 = _mm256_permute_ps(ymm3, 0xB1);\ - ymm3 = _mm256_add_ps(ymm2, ymm3);\ - /*Dividing numerator by denominator*/\ - vec1 = _mm256_div_ps(vec1, ymm3);\ - vec2 = _mm256_div_ps(vec2, ymm3);\ + scomplex b_data[8];\ + scomplex d11_data[4];\ + \ + _mm256_storeu_ps((float *)(b_data), vec1);\ + _mm256_storeu_ps((float *)(b_data + 4), vec2);\ + _mm256_storeu_ps((float *)(d11_data), ymm1);\ + \ + for(dim_t i = 0; i < 8; i++)\ + {\ + bli_cinvscals(d11_data[0],b_data[i]);\ + }\ + \ + vec1 = _mm256_loadu_ps((float *)b_data);\ + vec2 = _mm256_loadu_ps((float *)(b_data+4));\ + \ }\ } @@ -40308,43 +40231,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); } - - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); - ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); - ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm256_storeu_ps((float *)(d11_pack), ymm1); for(j = 0; (j+d_nr-1) < n; j += d_nr) { @@ -42555,43 +42448,13 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+cs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*1 + 1)); - ymm2 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*2 + 2)); - ymm3 = _mm256_broadcast_ps((__m128 const *)(a11+rs_a*3 + 3)); - ymm0 = _mm256_permute_ps(ymm0, 0x44); - ymm1 = _mm256_permute_ps(ymm1, 0x44); - ymm2 = _mm256_permute_ps(ymm2, 0x44); - ymm3 = _mm256_permute_ps(ymm3, 0x44); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); } - - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); - ymm2 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); - ymm1 = _mm256_blend_ps(ymm1, ymm2, 0xF0); - -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm256_storeu_ps((float *)(d11_pack), ymm1); for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) { @@ -44147,30 +44010,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+cs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+rs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,n_rem); } - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = (m-d_mr); (i+1) > 0; i -= d_mr) { @@ -44626,25 +44472,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_blend_ps(ymm0, ymm1, 0xC0); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = (m-d_mr); (i+1) > 0; i -= d_mr) { @@ -44899,7 +44730,6 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB scomplex *a01, *a11, *b10, *b11; //pointers that point to blocks for GEMM and TRSM - scomplex ones = {1.0, 1.0}; bool is_unitdiag = bli_obj_has_unit_diag(a); //scratch registers @@ -45658,37 +45488,17 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB } } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+cs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_broadcast_ps((__m128 const *) - (a11+rs_a*1 + 1)); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,n_rem); } - ymm1 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = 0; (i+d_mr-1) < m; i += d_mr) { @@ -46153,25 +45963,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB } } - - ymm1 = _mm256_broadcast_ps((__m128 const *)&ones); - ymm1 = _mm256_permute_ps(ymm1, 0x44); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_ps((__m128 const *)(a11)); - ymm1 = _mm256_blend_ps(ymm0, ymm1, 0xC0); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); - ymm4 = _mm256_mul_ps(ymm1, ymm1); - ymm6 = _mm256_permute_ps(ymm4, 0xB1); - ymm4 = _mm256_add_ps(ymm4, ymm6); - ymm1 = _mm256_mul_ps(ymm1, ymm7); - ymm1 = _mm256_div_ps(ymm1, ymm4); -#endif + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,n_rem); } - _mm_store_ps((float *)(d11_pack), - _mm256_extractf128_ps(ymm1,0)); for(i = 0; (i+d_mr-1) < m; i += d_mr) { From e0c94ee56a55dbbc3bbcde1b68720121e49e0264 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 22 Mar 2022 11:48:25 +0530 Subject: [PATCH 42/63] Added AOCL progress support for BLIS MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit -- AOCL libraries are used for lengthy computations which can go on for hours or days, once the operation is started, the user doesn’t get any update on current state of the computation. This (AOCL progress) feature enables user to receive a periodic update from the libraries. -- User registers a callback with the library if it is interested in receiving the periodic update. -- The library invokes this callback periodically with information about current state of the operation. -- The update frequency is statically set in the code, it can be modified as needed if the library is built from source. -- These feature is supported for GEMM and TRSM operations. -- Added example for GEMM and TRSM. -- Cleaned up and reformatted test_gemm.c and test_trsm.c to remove warnings and making indentation consistent across the file. AMD-Internal: [CPUPL-2082] Change-Id: I2aacdd8fb76f52e19e3850ee0295df49a8b7a90e --- aocl_dtl/aocldtl.h | 3 +- aocl_dtl/aoclos.c | 13 +- aocl_dtl/aoclos.h | 4 +- frame/3/gemm/bli_gemm_ker_var2.c | 18 +- frame/3/trsm/bli_trsm_xx_ker_var2.c | 55 +- frame/include/bli_config_macro_defs.h | 8 +- frame/thread/bli_l3_decor_openmp.c | 13 +- frame/thread/bli_l3_decor_single.c | 15 +- frame/util/CMakeLists.txt | 3 +- frame/util/bli_util.h | 5 +- frame/util/bli_util_progress.c | 56 ++ frame/util/bli_util_progress.h | 74 +++ test/test_gemm.c | 787 ++++++++++++++------------ test/test_trsm.c | 681 +++++++++++----------- 14 files changed, 1018 insertions(+), 717 deletions(-) create mode 100644 frame/util/bli_util_progress.c create mode 100644 frame/util/bli_util_progress.h diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 58c1a56079..7ce81561b7 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -5,7 +5,7 @@ * It provides defination for all macros to be * used by user to add debug/trace information. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -15,6 +15,7 @@ #include "aocldtlcf.h" #include "aocltpdef.h" #include "aoclflist.h" +#include "aoclos.h" #define TRACE_TYPE_FENTRY (1) #define TRACE_TYPE_FEXIT (2) diff --git a/aocl_dtl/aoclos.c b/aocl_dtl/aoclos.c index 92a489564e..896b1c89b3 100644 --- a/aocl_dtl/aoclos.c +++ b/aocl_dtl/aoclos.c @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ #include "aocltpdef.h" @@ -85,8 +85,15 @@ uint64 AOCL_getTimestamp(void) #else /* Non linux support */ AOCL_TID AOCL_gettid(void) { - /* stub for other os's */ - return 0; +#ifdef BLIS_ENABLE_OPENMP + return omp_get_thread_num(); +#else +#ifdef BLIS_ENABLE_PTHREADS + return pthread_self(); +#else + return 0; +#endif +#endif } pid_t AOCL_getpid(void) diff --git a/aocl_dtl/aoclos.h b/aocl_dtl/aoclos.h index 3d8e1cddcc..57e0c24902 100644 --- a/aocl_dtl/aoclos.h +++ b/aocl_dtl/aoclos.h @@ -3,7 +3,7 @@ * * Description : Abstraction for os services used by DTL. * - * Copyright (C) 2020, Advanced Micro Devices, Inc + * Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -19,7 +19,7 @@ #define AOCL_malloc malloc #define AOCL_free free -uint32 AOCL_gettid(void); +AOCL_TID AOCL_gettid(void); pid_t AOCL_getpid(void); uint64 AOCL_getTimestamp(void); diff --git a/frame/3/gemm/bli_gemm_ker_var2.c b/frame/3/gemm/bli_gemm_ker_var2.c index 5e0a4ddb70..dc1c3d14dc 100644 --- a/frame/3/gemm/bli_gemm_ker_var2.c +++ b/frame/3/gemm/bli_gemm_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2019, Advanced Micro Devices, Inc. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -407,6 +407,22 @@ void PASTEMAC(ch,varname) \ } \ } \ \ +/* Send progress update if the user has enabled it */ \ +if(AOCL_progress_ptr) { \ + /* Running total for current thread */ \ + tls_aoclprogress_counter += m * n * k; \ + /* Send the update only if enough number of elements are processes */ \ + if ((tls_aoclprogress_counter - tls_aoclprogress_last_update) >= AOCL_PROGRESS_FREQUENCY) \ + { \ + tls_aoclprogress_last_update = tls_aoclprogress_counter; \ + AOCL_PROGRESS_DT(*MKSTR(ch), \ + "gemm", \ + tls_aoclprogress_counter, \ + AOCL_gettid(), \ + bli_rntm_num_threads(rntm)); \ + }\ +} \ + \ /* PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: b1", k, NR, b1, NR, 1, "%4.1f", "" ); \ PASTEMAC(ch,fprintm)( stdout, "gemm_ker_var2: a1", MR, k, a1, 1, MR, "%4.1f", "" ); \ diff --git a/frame/3/trsm/bli_trsm_xx_ker_var2.c b/frame/3/trsm/bli_trsm_xx_ker_var2.c index de8cad065a..8d2f8689a9 100644 --- a/frame/3/trsm/bli_trsm_xx_ker_var2.c +++ b/frame/3/trsm/bli_trsm_xx_ker_var2.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018 - 2020, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -87,6 +87,59 @@ void bli_trsm_xx_ker_var2 cntl, thread ); + + // Send progress update if enabled + if (AOCL_progress_ptr) + { + + // Get the size of block processed in + // this iteration, add it to the accumulated + // total and send the update. + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width(a); + + num_t dt = bli_obj_dt(c); + char dt_c; + + // Running total for current thread. + tls_aoclprogress_counter += m * n * k; + + // Send the update only if number of elements processes so far + // has exceeded the freqency of reporting. + if ((tls_aoclprogress_counter - tls_aoclprogress_last_update) >= + AOCL_PROGRESS_FREQUENCY) + { + + // reset the last update counter for next iteration. + tls_aoclprogress_last_update = tls_aoclprogress_counter; + + switch (dt) + { + case BLIS_FLOAT: + dt_c = 's'; + break; + case BLIS_DOUBLE: + dt_c = 'd'; + break; + case BLIS_SCOMPLEX: + dt_c = 'c'; + break; + case BLIS_DCOMPLEX: + dt_c = 'z'; + break; + default: + dt_c = ' '; + } + + AOCL_PROGRESS_DT(dt_c, + "trsm", + tls_aoclprogress_counter, + AOCL_gettid(), + bli_rntm_num_threads(rntm)); + } + } + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_6); } diff --git a/frame/include/bli_config_macro_defs.h b/frame/include/bli_config_macro_defs.h index d00df2f0be..c9e597c9a6 100644 --- a/frame/include/bli_config_macro_defs.h +++ b/frame/include/bli_config_macro_defs.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -260,5 +260,11 @@ #endif +#ifdef BLIS_OS_WINDOWS + #define BLIS_TLS_TYPE __declspec(thread) +#else + #define BLIS_TLS_TYPE __thread +#endif + #endif diff --git a/frame/thread/bli_l3_decor_openmp.c b/frame/thread/bli_l3_decor_openmp.c index 0bf3ad8547..b01c208a30 100644 --- a/frame/thread/bli_l3_decor_openmp.c +++ b/frame/thread/bli_l3_decor_openmp.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -140,6 +140,17 @@ void bli_l3_thread_decorator bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); #if 1 + // Reset the progress state to 0 as we are starting new operations. + // This counter track running progress in current thread. + tls_aoclprogress_counter = 0; + + // We send the update only after certain threshold is reached, + // The thresold is defined as AOCL_PROGRESS_FREQUENCY. + // This variable stores the counter value when last update was sent. + // It is compared with current counter value to see if it is time to + // send the next update. + tls_aoclprogress_last_update = 0; + func ( alpha, diff --git a/frame/thread/bli_l3_decor_single.c b/frame/thread/bli_l3_decor_single.c index 12f27ad873..444583e73e 100644 --- a/frame/thread/bli_l3_decor_single.c +++ b/frame/thread/bli_l3_decor_single.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2018, Advanced Micro Devices, Inc. + Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -115,7 +115,18 @@ void bli_l3_thread_decorator // Create the root node of the thread's thrinfo_t structure. bli_l3_thrinfo_create_root( tid, gl_comm, rntm_p, cntl_use, &thread ); - + + // Reset the progress state to 0 as we are starting new operations. + // This counter track running progress in current thread. + tls_aoclprogress_counter = 0; + + // We send the update only after certain threshold is reached, + // The thresold is defined as AOCL_PROGRESS_FREQUENCY. + // This variable stores the counter value when last update was sent. + // It is compared with current counter value to see if it is time to + // send the next update. + tls_aoclprogress_last_update = 0; + func ( alpha, diff --git a/frame/util/CMakeLists.txt b/frame/util/CMakeLists.txt index c20d7c525d..13fd53fc52 100644 --- a/frame/util/CMakeLists.txt +++ b/frame/util/CMakeLists.txt @@ -1,4 +1,4 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE @@ -13,4 +13,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_unb_var1.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_update.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_api_wrap.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_util_progress.c ) diff --git a/frame/util/bli_util.h b/frame/util/bli_util.h index 3c4e5722af..f7be273526 100644 --- a/frame/util/bli_util.h +++ b/frame/util/bli_util.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -63,3 +63,6 @@ // Header file define different formats of BLAS APIs- uppercase with // and without underscore, lowercase without underscore. #include "bli_util_api_wrap.h" + +// Public interface for the progress feature +#include "bli_util_progress.h" \ No newline at end of file diff --git a/frame/util/bli_util_progress.c b/frame/util/bli_util_progress.c new file mode 100644 index 0000000000..4097eb1126 --- /dev/null +++ b/frame/util/bli_util_progress.c @@ -0,0 +1,56 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +// The progress feature periodically updates the user with current state +// of the operation, We maintain the progress for each thread separately +// following variables are used to store the elements processed in each +// thread using thread local storage. +BLIS_TLS_TYPE dim_t tls_aoclprogress_counter; + +// Store the counter when last update was sent, this is used to implement +// update freqency. +BLIS_TLS_TYPE dim_t tls_aoclprogress_last_update; + + +// AOCL_progress_ptr contains the pointer to the callback function +// By default it is set to NULL, which effectivly disabled the +// progress feature. +AOCL_progress_callback AOCL_progress_ptr = NULL; + +void AOCL_BLIS_set_progress(AOCL_progress_callback func) +{ + AOCL_progress_ptr = func; +} \ No newline at end of file diff --git a/frame/util/bli_util_progress.h b/frame/util/bli_util_progress.h new file mode 100644 index 0000000000..0e2a63eb1c --- /dev/null +++ b/frame/util/bli_util_progress.h @@ -0,0 +1,74 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifndef BLI_UTIL_PROGRESS_H +#define BLI_UTIL_PROGRESS_H + +// Public interface for the end user. + +typedef dim_t (*AOCL_progress_callback)(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads); + +BLIS_EXPORT_BLIS void AOCL_BLIS_set_progress(AOCL_progress_callback func); + +// Private interfaces for internal use + +extern AOCL_progress_callback AOCL_progress_ptr; + +extern BLIS_TLS_TYPE dim_t tls_aoclprogress_counter; +extern BLIS_TLS_TYPE dim_t tls_aoclprogress_last_update; + +// Define the frequency of reporting (number of elements). +// Progress update will be sent only after these many +// elements are processed in the current thread. +#define AOCL_PROGRESS_FREQUENCY 1e+9 + +#define MAX_API_NAME_LEN 20 + +// Macro to send update using datatype character and the api name +#define AOCL_PROGRESS_DT(dt, api, progress, tid, nt) \ + char buf[MAX_API_NAME_LEN]; \ + snprintf(buf, MAX_API_NAME_LEN, "%c%s", dt, api); \ + (*AOCL_progress_ptr) (buf, strlen(buf), progress, tid, nt); \ + +// Macro to send update using api name alone. +#define AOCL_PROGRESS_NAME(api, progress, tid, nt) \ + char buf[MAX_API_NAME_LEN]; \ + snprintf(buf, MAX_API_NAME_LEN, "%s", dt, api); \ + (*AOCL_progress_ptr) (buf, strlen(buf), progress, tid, nt); \ + +#endif // BLI_UTIL_PROGRESS_H diff --git a/test/test_gemm.c b/test/test_gemm.c index 25fc5e3d8d..81b7e36616 100644 --- a/test/test_gemm.c +++ b/test/test_gemm.c @@ -10,14 +10,14 @@ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name of The University of Texas nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name of The University of Texas nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -47,426 +47,471 @@ // uncomment to enable cblas interface //#define CBLAS -int main( int argc, char** argv ) +// Uncomment to enable progress printing. +//#define PROGRESS_ENABLED + +#ifdef PROGRESS_ENABLED +dim_t AOCL_progress(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads) +{ + printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", + api, lapi, total_threads, current_thread, progress); + + return 0; +} +#endif + +int main(int argc, char **argv) { - obj_t a, b, c; - obj_t c_save; - obj_t alpha, beta; - dim_t m, n, k; - inc_t lda, ldb, ldc; - num_t dt, dt_a; - inc_t r, n_repeats; - trans_t transa; - trans_t transb; - f77_char f77_transa; - f77_char f77_transb; - - double dtime; - double dtime_save; - double gflops; - - //bli_init(); - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); - - n_repeats = 300; - - //dt = BLIS_FLOAT; - dt = BLIS_DOUBLE; - //dt = BLIS_SCOMPLEX; - //dt = BLIS_DCOMPLEX; - - if( bli_is_real( dt ) || bli_is_scomplex( dt ) ) + obj_t a, b, c; + obj_t c_save; + obj_t alpha, beta; + dim_t m, n, k; + inc_t lda, ldb, ldc; + num_t dt, dt_a; + inc_t r, n_repeats; + trans_t transa; + trans_t transb; + f77_char f77_transa; + f77_char f77_transb; + + double dtime; + double dtime_save; + double gflops; + +#ifdef PROGRESS_ENABLED + AOCL_BLIS_set_progress(AOCL_progress); +#endif + + // bli_init(); + // bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + + n_repeats = 300; + + // dt = BLIS_FLOAT; + dt = BLIS_DOUBLE; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; + + if (bli_is_real(dt) || bli_is_scomplex(dt)) dt_a = dt; else { dt_a = dt; // Enable the following to call - // dzgemm - //dt_a = BLIS_DOUBLE; + // dzgemm + // dt_a = BLIS_DOUBLE; } const char stor_scheme = 'C'; - transa = BLIS_NO_TRANSPOSE; - transb = BLIS_NO_TRANSPOSE; - - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + transa = BLIS_NO_TRANSPOSE; + transb = BLIS_NO_TRANSPOSE; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_trans(transb, &f77_transb); printf("BLIS Library version is : %s\n", bli_info_get_version_str()); #ifdef FILE_IN_OUT - FILE* fin = NULL; - FILE* fout = NULL; - if (argc < 3){ - printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); - exit(1); - } - fin = fopen(argv[1], "r"); - if (fin == NULL){ - printf("Error opening the file %s\n", argv[1]); - exit(1); - } - fout = fopen(argv[2], "w"); - if (fout == NULL){ - printf("Error opening output file %s\n", argv[2]); - exit(1); - } - fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\n"); - printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\n"); - - while (fscanf(fin, "%lld %lld %lld %lld %lld %lld\n", &m, &k, &n, &lda, &ldb, &ldc) == 6) - { - // dimensions should not be greater than leading dimensions - // These are valid only when Op(A) = n and op(B) = n - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - if ((m > lda) || (k > ldb) || (m > ldc)) continue; - }else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) { - // leading dimension should be greater than number of cols - if ((k > lda) || (n > ldb) || (n > ldc)) continue; - }else { - printf("Invalid Storage type\n"); - continue; - } + FILE *fin = NULL; + FILE *fout = NULL; + if (argc < 3) + { + printf("Usage: ./test_gemm_XX.x input.csv output.csv\n"); + exit(1); + } + fin = fopen(argv[1], "r"); + if (fin == NULL) + { + printf("Error opening the file %s\n", argv[1]); + exit(1); + } + fout = fopen(argv[2], "w"); + if (fout == NULL) + { + printf("Error opening output file %s\n", argv[2]); + exit(1); + } + fprintf(fout, "m\t k\t n\t cs_a\t cs_b\t cs_c\t gflops\n"); + printf("~~~~~~~~~~_BLAS\t m\t k\t n\t cs_a\t cs_b\t cs_c \t gflops\n"); + + while (fscanf(fin, "%ld %ld %ld %ld %ld %ld\n", &m, &k, &n, &lda, &ldb, &ldc) == 6) + { + // dimensions should not be greater than leading dimensions + // These are valid only when Op(A) = n and op(B) = n + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + if ((m > lda) || (k > ldb) || (m > ldc)) + continue; + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + // leading dimension should be greater than number of cols + if ((k > lda) || (n > ldb) || (n > ldc)) + continue; + } + else + { + printf("Invalid Storage type\n"); + continue; + } #else - dim_t p, p_begin, p_end, p_inc; - dim_t m_input, n_input, k_input; - p_begin = 200; - p_end = 2000; - p_inc = 200; - - m_input = n_input = k_input = -1; - for ( p = p_begin; p <= p_end; p += p_inc ) - { - if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; - if ( k_input < 0 ) k = p * ( dim_t )abs(k_input); - else k = ( dim_t ) k_input; - - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) { - lda = m; ldb = k, ldc = m; - }else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) { - lda = k; ldb = n, ldc = n; - } + dim_t p, p_begin, p_end, p_inc; + dim_t m_input, n_input, k_input; + p_begin = 200; + p_end = 2000; + p_inc = 200; + + m_input = n_input = k_input = -1; + for (p = p_begin; p <= p_end; p += p_inc) + { + if (m_input < 0) + m = p * (dim_t)labs(m_input); + else + m = (dim_t)m_input; + if (n_input < 0) + n = p * (dim_t)labs(n_input); + else + n = (dim_t)n_input; + if (k_input < 0) + k = p * (dim_t)labs(k_input); + else + k = (dim_t)k_input; + + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + lda = m; + ldb = k, ldc = m; + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + lda = k; + ldb = n, ldc = n; + } #endif - bli_obj_create( dt, 1, 1, 0, 0, &alpha); - bli_obj_create( dt, 1, 1, 0, 0, &beta ); - - siz_t elem_size = bli_dt_size( dt ); - - lda = bli_align_dim_to_size( lda, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - ldb = bli_align_dim_to_size( ldb, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - ldc = bli_align_dim_to_size( ldc, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - - // Will verify the leading dimension is powers of 2 and add 64bytes. - inc_t n_bytes = lda*sizeof(dt_a); - - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - lda += BLIS_SIMD_ALIGN_SIZE/sizeof(dt_a); - - n_bytes = ldb*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - ldb += BLIS_SIMD_ALIGN_SIZE/sizeof(dt); - - n_bytes = ldc*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - ldc += BLIS_SIMD_ALIGN_SIZE/sizeof(dt); - - if( (stor_scheme == 'C') || (stor_scheme == 'c') ) - { - // Col-major Order - bli_obj_create( dt_a, m, k, 1, lda, &a ); - bli_obj_create( dt, k, n, 1, ldb, &b ); - bli_obj_create( dt, m, n, 1, ldc, &c ); - bli_obj_create( dt, m, n, 1, ldc, &c_save ); - } - else if( (stor_scheme == 'R') || (stor_scheme == 'r') ) - { - // Row-major Order - bli_obj_create( dt_a, m, k, lda, 1, &a ); - bli_obj_create( dt, k, n, ldb, 1, &b ); - bli_obj_create( dt, m, n, ldc, 1, &c ); - bli_obj_create( dt, m, n, ldc, 1, &c_save ); - } - else - { - printf("Invalid Storage type\n"); - continue; - } + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_obj_create(dt, 1, 1, 0, 0, &beta); + + siz_t elem_size = bli_dt_size(dt); + + lda = bli_align_dim_to_size(lda, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + ldb = bli_align_dim_to_size(ldb, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + ldc = bli_align_dim_to_size(ldc, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + + // Will verify the leading dimension is powers of 2 and add 64bytes. + inc_t n_bytes = lda * sizeof(dt_a); + + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + lda += BLIS_SIMD_ALIGN_SIZE / sizeof(dt_a); + + n_bytes = ldb * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + ldb += BLIS_SIMD_ALIGN_SIZE / sizeof(dt); + + n_bytes = ldc * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + ldc += BLIS_SIMD_ALIGN_SIZE / sizeof(dt); + + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + // Col-major Order + bli_obj_create(dt_a, m, k, 1, lda, &a); + bli_obj_create(dt, k, n, 1, ldb, &b); + bli_obj_create(dt, m, n, 1, ldc, &c); + bli_obj_create(dt, m, n, 1, ldc, &c_save); + } + else if ((stor_scheme == 'R') || (stor_scheme == 'r')) + { + // Row-major Order + bli_obj_create(dt_a, m, k, lda, 1, &a); + bli_obj_create(dt, k, n, ldb, 1, &b); + bli_obj_create(dt, m, n, ldc, 1, &c); + bli_obj_create(dt, m, n, ldc, 1, &c_save); + } + else + { + printf("Invalid Storage type\n"); + continue; + } #ifdef MATRIX_INITIALISATION - bli_randm( &a ); - bli_randm( &b ); - bli_randm( &c ); + bli_randm(&a); + bli_randm(&b); + bli_randm(&c); #endif - bli_obj_set_conjtrans( transa, &a); - bli_obj_set_conjtrans( transb, &b); - bli_setsc( (0.9/1.0), 0.2, &alpha ); - bli_setsc( -(1.1/1.0), 0.3, &beta ); - - bli_copym( &c, &c_save ); - dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) - { - bli_copym( &c_save, &c ); - dtime = bli_clock(); + bli_obj_set_conjtrans(transa, &a); + bli_obj_set_conjtrans(transb, &b); + bli_setsc((0.9 / 1.0), 0.2, &alpha); + bli_setsc(-(1.1 / 1.0), 0.3, &beta); + + bli_copym(&c, &c_save); + dtime_save = DBL_MAX; + for (r = 0; r < n_repeats; ++r) + { + bli_copym(&c_save, &c); + dtime = bli_clock(); #ifdef BLIS - bli_gemm( &alpha, - &a, - &b, - &beta, - &c ); + bli_gemm(&alpha, + &a, + &b, + &beta, + &c); #else - f77_int lda, ldb, ldc; - f77_int mm = bli_obj_length( &c ); - f77_int kk = bli_obj_width_after_trans( &a ); - f77_int nn = bli_obj_width( &c ); + f77_int lda, ldb, ldc; + f77_int mm = bli_obj_length(&c); + f77_int kk = bli_obj_width_after_trans(&a); + f77_int nn = bli_obj_width(&c); #ifdef CBLAS - enum CBLAS_ORDER cblas_order; - enum CBLAS_TRANSPOSE cblas_transa; - enum CBLAS_TRANSPOSE cblas_transb; - - if ( bli_obj_row_stride( &c ) == 1 ){ - cblas_order = CblasColMajor; - }else{ - cblas_order = CblasRowMajor; - } - - if( bli_is_trans( transa ) ) - cblas_transa = CblasTrans; - else if( bli_is_conjtrans( transa ) ) - cblas_transa = CblasConjTrans; - else - cblas_transa = CblasNoTrans; - - if( bli_is_trans( transb ) ) - cblas_transb = CblasTrans; - else if( bli_is_conjtrans( transb ) ) - cblas_transb = CblasConjTrans; - else - cblas_transb = CblasNoTrans; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_TRANSPOSE cblas_transb; + + if (bli_obj_row_stride(&c) == 1) + { + cblas_order = CblasColMajor; + } + else + { + cblas_order = CblasRowMajor; + } + + if (bli_is_trans(transa)) + cblas_transa = CblasTrans; + else if (bli_is_conjtrans(transa)) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if (bli_is_trans(transb)) + cblas_transb = CblasTrans; + else if (bli_is_conjtrans(transb)) + cblas_transb = CblasConjTrans; + else + cblas_transb = CblasNoTrans; #else - f77_char f77_transa; - f77_char f77_transb; - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_trans( transb, &f77_transb ); + f77_char f77_transa; + f77_char f77_transb; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_trans(transb, &f77_transb); #endif - if( (stor_scheme == 'C') || (stor_scheme == 'c') ){ - lda = bli_obj_col_stride( &a ); - ldb = bli_obj_col_stride( &b ); - ldc = bli_obj_col_stride( &c ); - } else { - lda = bli_obj_row_stride( &a ); - ldb = bli_obj_row_stride( &b ); - ldc = bli_obj_row_stride( &c ); - } - - if ( bli_is_float( dt ) ) - { - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* bp = bli_obj_buffer( &b ); - float* betap = bli_obj_buffer( &beta ); - float* cp = bli_obj_buffer( &c ); + if ((stor_scheme == 'C') || (stor_scheme == 'c')) + { + lda = bli_obj_col_stride(&a); + ldb = bli_obj_col_stride(&b); + ldc = bli_obj_col_stride(&c); + } + else + { + lda = bli_obj_row_stride(&a); + ldb = bli_obj_row_stride(&b); + ldc = bli_obj_row_stride(&c); + } + + if (bli_is_float(dt)) + { + float *alphap = bli_obj_buffer(&alpha); + float *ap = bli_obj_buffer(&a); + float *bp = bli_obj_buffer(&b); + float *betap = bli_obj_buffer(&beta); + float *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_sgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - *alphap, - ap, lda, - bp, ldb, - *betap, - cp, ldc - ); + cblas_sgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc); #else - sgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + sgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_double( dt ) ) - { - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* bp = bli_obj_buffer( &b ); - double* betap = bli_obj_buffer( &beta ); - double* cp = bli_obj_buffer( &c ); + } + else if (bli_is_double(dt)) + { + double *alphap = bli_obj_buffer(&alpha); + double *ap = bli_obj_buffer(&a); + double *bp = bli_obj_buffer(&b); + double *betap = bli_obj_buffer(&beta); + double *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_dgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - *alphap, - ap, lda, - bp, ldb, - *betap, - cp, ldc - ); + cblas_dgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + *alphap, + ap, lda, + bp, ldb, + *betap, + cp, ldc); #else - dgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + dgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_scomplex( dt ) ) - { - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* bp = bli_obj_buffer( &b ); - scomplex* betap = bli_obj_buffer( &beta ); - scomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_scomplex(dt)) + { + scomplex *alphap = bli_obj_buffer(&alpha); + scomplex *ap = bli_obj_buffer(&a); + scomplex *bp = bli_obj_buffer(&b); + scomplex *betap = bli_obj_buffer(&beta); + scomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_cgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - alphap, - ap, lda, - bp, ldb, - betap, - cp, ldc - ); + cblas_cgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc); #else - cgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); + cgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); #endif - }else if ( bli_is_dcomplex( dt ) ) - { - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* bp = bli_obj_buffer( &b ); - dcomplex* betap = bli_obj_buffer( &beta ); - dcomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_dcomplex(dt)) + { + dcomplex *alphap = bli_obj_buffer(&alpha); + dcomplex *ap = bli_obj_buffer(&a); + dcomplex *bp = bli_obj_buffer(&b); + dcomplex *betap = bli_obj_buffer(&beta); + dcomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_zgemm( cblas_order, - cblas_transa, - cblas_transb, - mm, - nn, - kk, - alphap, - ap, lda, - bp, ldb, - betap, - cp, ldc - ); + cblas_zgemm(cblas_order, + cblas_transa, + cblas_transb, + mm, + nn, + kk, + alphap, + ap, lda, + bp, ldb, + betap, + cp, ldc); #else #if 1 - if( bli_is_double( dt_a ) ) - { - dzgemm_( - &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - (double*)ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc - ); - } - else - { - zgemm_( &f77_transa, - &f77_transb, - &mm, - &nn, - &kk, - alphap, - ap, (f77_int*)&lda, - bp, (f77_int*)&ldb, - betap, - cp, (f77_int*)&ldc ); - } + if (bli_is_double(dt_a)) + { + dzgemm_( + &f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + (double *)ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); + } + else + { + zgemm_(&f77_transa, + &f77_transb, + &mm, + &nn, + &kk, + alphap, + ap, (f77_int *)&lda, + bp, (f77_int *)&ldb, + betap, + cp, (f77_int *)&ldc); + } #endif #endif - } + } #endif #ifdef PRINT - bli_printm( "a", &a, "%4.1f", "" ); - bli_printm( "b", &b, "%4.1f", "" ); - bli_printm( "c", &c, "%4.1f", "" ); - bli_printm( "c after", &c, "%4.1f", "" ); - exit(1); + bli_printm("a", &a, "%4.1f", ""); + bli_printm("b", &b, "%4.1f", ""); + bli_printm("c", &c, "%4.1f", ""); + bli_printm("c after", &c, "%4.1f", ""); + exit(1); #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); - }//nrepeats + dtime_save = bli_clock_min_diff(dtime_save, dtime); + } // nrepeats - gflops = ( 2.0 * m * k * n ) / ( dtime_save * 1.0e9 ); - if (bli_is_dcomplex(dt) && (bli_is_double(dt_a))) - gflops *= 2.0; - else if ( bli_is_complex( dt ) ) gflops *= 4.0; + gflops = (2.0 * m * k * n) / (dtime_save * 1.0e9); + if (bli_is_dcomplex(dt) && (bli_is_double(dt_a))) + gflops *= 2.0; + else if (bli_is_complex(dt)) + gflops *= 4.0; #ifdef BLIS - printf("data_gemm_blis" ); + printf("data_gemm_blis"); #else - printf("data_gemm_%s", BLAS ); + printf("data_gemm_%s", BLAS); #endif - #ifdef FILE_IN_OUT - printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ - ( unsigned long )m,( unsigned long )k,( unsigned long )n, - (unsigned long)lda,(unsigned long)ldb,(unsigned long)ldc,gflops); + printf("%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", + (unsigned long)m, (unsigned long)k, (unsigned long)n, + (unsigned long)lda, (unsigned long)ldb, (unsigned long)ldc, gflops); - fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", \ - ( unsigned long )m,( unsigned long )k,( unsigned long )n, - (unsigned long)lda,(unsigned long)ldb,(unsigned long)ldc,gflops); - fflush(fout); + fprintf(fout, "%6lu \t %4lu \t %4lu \t %4lu \t %4lu \t %4lu \t %6.3f\n", + (unsigned long)m, (unsigned long)k, (unsigned long)n, + (unsigned long)lda, (unsigned long)ldb, (unsigned long)ldc, gflops); + fflush(fout); #else - printf( "( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )m,( unsigned long )k, - ( unsigned long )n, gflops ); + printf("( %2lu, 1:4 ) = [ %4lu %4lu %4lu %7.2f ];\n", + (unsigned long)(p - p_begin) / p_inc + 1, + (unsigned long)m, (unsigned long)k, + (unsigned long)n, gflops); #endif - bli_obj_free( &alpha ); - bli_obj_free( &beta ); + bli_obj_free(&alpha); + bli_obj_free(&beta); - bli_obj_free( &a ); - bli_obj_free( &b ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); - }//while + bli_obj_free(&a); + bli_obj_free(&b); + bli_obj_free(&c); + bli_obj_free(&c_save); + } // while - //bli_finalize(); + // bli_finalize(); #ifdef FILE_IN_OUT - fclose(fin); - fclose(fout); + fclose(fin); + fclose(fout); #endif - return 0; + return 0; } diff --git a/test/test_trsm.c b/test/test_trsm.c index 72156d92fe..f6709f5d7f 100644 --- a/test/test_trsm.c +++ b/test/test_trsm.c @@ -5,19 +5,19 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - - Neither the name(s) of the copyright holder(s) nor the names of its - contributors may be used to endorse or promote products derived - from this software without specific prior written permission. + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT @@ -50,14 +50,31 @@ #define CACHE_LINE_SIZE 64 -int main( int argc, char** argv ) +// Uncomment to enable progress printing. +//#define PROGRESS_ENABLED + +#ifdef PROGRESS_ENABLED +dim_t AOCL_progress(char *api, + dim_t lapi, + dim_t progress, + dim_t current_thread, + dim_t total_threads) +{ + printf("\n%s, len = %ld, nt = %ld, tid = %ld, Processed %ld Elements", + api, lapi, total_threads, current_thread, progress); + + return 0; +} +#endif + +int main(int argc, char **argv) { obj_t a, c; obj_t c_save; obj_t alpha; dim_t m, n; num_t dt; - int r, n_repeats; + int r, n_repeats; side_t side; uplo_t uploa; trans_t transa; @@ -72,16 +89,20 @@ int main( int argc, char** argv ) double gflops; #ifdef FILE_IN_OUT - FILE* fin = NULL; - FILE* fout = NULL; + FILE *fin = NULL; + FILE *fout = NULL; #else dim_t p; dim_t p_begin, p_end, p_inc; - int m_input, n_input; + int m_input, n_input; - //bli_init(); +#ifdef PROGRESS_ENABLED + AOCL_BLIS_set_progress(AOCL_progress); +#endif + + // bli_init(); - //bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); + // bli_error_checking_level_set( BLIS_NO_ERROR_CHECKING ); #ifndef PRINT p_begin = 200; @@ -102,26 +123,26 @@ int main( int argc, char** argv ) n_repeats = 3; - //dt = BLIS_FLOAT; + // dt = BLIS_FLOAT; dt = BLIS_DOUBLE; - //dt = BLIS_SCOMPLEX; - //dt = BLIS_DCOMPLEX; + // dt = BLIS_SCOMPLEX; + // dt = BLIS_DCOMPLEX; #ifdef FILE_IN_OUT - if(argc < 3) + if (argc < 3) { printf("Usage: ./test_trsm_XX.x input.csv output.csv\n"); exit(1); } fin = fopen(argv[1], "r"); - if(fin == NULL) + if (fin == NULL) { printf("Error opening the file %s\n", argv[1]); exit(1); } fout = fopen(argv[2], "w"); - if(fout == NULL) + if (fout == NULL) { printf("Error opening the file %s\n", argv[2]); exit(1); @@ -130,425 +151,421 @@ int main( int argc, char** argv ) inc_t cs_b; #ifdef READ_ALL_PARAMS_FROM_FILE char side_c, uploa_c, transa_c, diaga_c; - + fprintf(fout, "side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); printf("~~~~~~~_BLAS\t side, uploa, transa, diaga, m\t n\t cs_a\t cs_b\t gflops\n"); - while(fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) + while (fscanf(fin, "%c %c %c %c %ld %ld %ld %ld\n", &side_c, &uploa_c, &transa_c, &diaga_c, &m, &n, &cs_a, &cs_b) == 8) { - if( 'l' == side_c|| 'L' == side_c) - side = BLIS_LEFT; - else if('r' == side_c || 'R' == side_c) - side = BLIS_RIGHT; - else - { - printf("Invalid entry for the argument 'side':%c\n",side_c); - continue; - } + if ('l' == side_c || 'L' == side_c) + side = BLIS_LEFT; + else if ('r' == side_c || 'R' == side_c) + side = BLIS_RIGHT; + else + { + printf("Invalid entry for the argument 'side':%c\n", side_c); + continue; + } - if('l' == uploa_c || 'L' == uploa_c) - uploa = BLIS_LOWER; - else if('u' == uploa_c || 'U' == uploa_c) - uploa = BLIS_UPPER; - else - { - printf("Invalid entry for the argument 'uplo':%c\n",uploa_c); - continue; - } + if ('l' == uploa_c || 'L' == uploa_c) + uploa = BLIS_LOWER; + else if ('u' == uploa_c || 'U' == uploa_c) + uploa = BLIS_UPPER; + else + { + printf("Invalid entry for the argument 'uplo':%c\n", uploa_c); + continue; + } - if('t' == transa_c || 'T' == transa_c) - transa = BLIS_TRANSPOSE; - else if('n' == transa_c || 'N' == transa_c) - transa = BLIS_NO_TRANSPOSE; - else - { - printf("Invalid entry for the argument 'transa':%c\n",transa_c); - continue; - } - - if('u' == diaga_c || 'U' == diaga_c) - diaga = BLIS_UNIT_DIAG; - else if('n' == diaga_c || 'N' == diaga_c) - diaga = BLIS_NONUNIT_DIAG; - else - { - printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); - continue; - } + if ('t' == transa_c || 'T' == transa_c) + transa = BLIS_TRANSPOSE; + else if ('n' == transa_c || 'N' == transa_c) + transa = BLIS_NO_TRANSPOSE; + else + { + printf("Invalid entry for the argument 'transa':%c\n", transa_c); + continue; + } + + if ('u' == diaga_c || 'U' == diaga_c) + diaga = BLIS_UNIT_DIAG; + else if ('n' == diaga_c || 'N' == diaga_c) + diaga = BLIS_NONUNIT_DIAG; + else + { + printf("Invalid entry for the argument 'diaga':%c\n", diaga_c); + continue; + } #else - + fprintf(fout, "m\t n\t cs_a\t cs_b\t gflops\n"); printf("~~~~~~~_BLAS\t m\t n\t cs_a\t cs_b\t gflops\n"); - while(fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) + while (fscanf(fin, "%ld %ld %ld %ld\n", &m, &n, &cs_a, &cs_b) == 4) { - - side = BLIS_LEFT; - //side = BLIS_RIGHT; - uploa = BLIS_LOWER; - //uploa = BLIS_UPPER; + side = BLIS_LEFT; + // side = BLIS_RIGHT; - transa = BLIS_NO_TRANSPOSE; + uploa = BLIS_LOWER; + // uploa = BLIS_UPPER; - diaga = BLIS_NONUNIT_DIAG; + transa = BLIS_NO_TRANSPOSE; + diaga = BLIS_NONUNIT_DIAG; #endif - bli_param_map_blis_to_netlib_side( side, &f77_side ); - bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + bli_param_map_blis_to_netlib_side(side, &f77_side); + bli_param_map_blis_to_netlib_uplo(uploa, &f77_uploa); + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_diag(diaga, &f77_diaga); + siz_t elem_size = bli_dt_size(dt); - siz_t elem_size = bli_dt_size( dt ); + cs_a = bli_align_dim_to_size(cs_a, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); + cs_b = bli_align_dim_to_size(cs_b, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE); - cs_a = bli_align_dim_to_size( cs_a, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); - cs_b = bli_align_dim_to_size( cs_b, elem_size, BLIS_HEAP_STRIDE_ALIGN_SIZE ); + // Will verify the leading dimension is powers of 2 and add 64bytes. + inc_t n_bytes = cs_a * sizeof(dt); - //Will verify the leading dimension is powers of 2 and add 64bytes. - inc_t n_bytes = cs_a*sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + cs_a += CACHE_LINE_SIZE / sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - cs_a += CACHE_LINE_SIZE/sizeof(dt); + n_bytes = cs_b * sizeof(dt); + if ((n_bytes != 0) && !(n_bytes & (n_bytes - 1))) // check whether n_bytes is power of 2. + cs_b += CACHE_LINE_SIZE / sizeof(dt); - n_bytes = cs_b*sizeof(dt); - if((n_bytes!=0) && !(n_bytes&(n_bytes-1)))// check whether n_bytes is power of 2. - cs_b += CACHE_LINE_SIZE/sizeof(dt); + if (bli_is_left(side) && ((m > cs_a) || (m > cs_b))) + continue; // leading dimension should be greater than number of rows + if (bli_is_right(side) && ((n > cs_a) || (m > cs_b))) + continue; // leading dimension should be greater than number of rows - if(bli_is_left(side) && ((m > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows - - if(bli_is_right(side) && ((n > cs_a) || (m > cs_b))) continue; //leading dimension should be greater than number of rows - - if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, 1, m, &a ); + if (bli_is_left(side)) + bli_obj_create(dt, m, m, 1, m, &a); else - bli_obj_create( dt, n, n, 1, n, &a ); - bli_obj_create( dt, m, n, 1, m, &c ); - bli_obj_create( dt, m, n, 1, m, &c_save ); + bli_obj_create(dt, n, n, 1, n, &a); + bli_obj_create(dt, m, n, 1, m, &c); + bli_obj_create(dt, m, n, 1, m, &c_save); #else - for ( p = p_end; p >= p_begin; p -= p_inc ) + for (p = p_end; p >= p_begin; p -= p_inc) { - if ( m_input < 0 ) m = p * ( dim_t )abs(m_input); - else m = ( dim_t ) m_input; - if ( n_input < 0 ) n = p * ( dim_t )abs(n_input); - else n = ( dim_t ) n_input; + if (m_input < 0) + m = p * (dim_t)abs(m_input); + else + m = (dim_t)m_input; + if (n_input < 0) + n = p * (dim_t)abs(n_input); + else + n = (dim_t)n_input; - - side = BLIS_LEFT; - //side = BLIS_RIGHT; + side = BLIS_LEFT; + // side = BLIS_RIGHT; - uploa = BLIS_LOWER; - //uploa = BLIS_UPPER; + uploa = BLIS_LOWER; + // uploa = BLIS_UPPER; - transa = BLIS_NO_TRANSPOSE; + transa = BLIS_NO_TRANSPOSE; - diaga = BLIS_NONUNIT_DIAG; + diaga = BLIS_NONUNIT_DIAG; - bli_param_map_blis_to_netlib_side( side, &f77_side ); - bli_param_map_blis_to_netlib_uplo( uploa, &f77_uploa ); - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); - bli_param_map_blis_to_netlib_diag( diaga, &f77_diaga ); + bli_param_map_blis_to_netlib_side(side, &f77_side); + bli_param_map_blis_to_netlib_uplo(uploa, &f77_uploa); + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); + bli_param_map_blis_to_netlib_diag(diaga, &f77_diaga); - if ( bli_is_left( side ) ) - bli_obj_create( dt, m, m, 0, 0, &a ); + if (bli_is_left(side)) + bli_obj_create(dt, m, m, 0, 0, &a); else - bli_obj_create( dt, n, n, 0, 0, &a ); - bli_obj_create( dt, m, n, 0, 0, &c ); - bli_obj_create( dt, m, n, 0, 0, &c_save ); + bli_obj_create(dt, n, n, 0, 0, &a); + bli_obj_create(dt, m, n, 0, 0, &c); + bli_obj_create(dt, m, n, 0, 0, &c_save); #endif - bli_randm( &a ); - bli_randm( &c ); + bli_randm(&a); + bli_randm(&c); - bli_obj_set_struc( BLIS_TRIANGULAR, &a ); - bli_obj_set_uplo( uploa, &a ); - bli_obj_set_conjtrans( transa, &a ); - bli_obj_set_diag( diaga, &a ); + bli_obj_set_struc(BLIS_TRIANGULAR, &a); + bli_obj_set_uplo(uploa, &a); + bli_obj_set_conjtrans(transa, &a); + bli_obj_set_diag(diaga, &a); // Randomize A and zero the unstored triangle to ensure the // implementation reads only from the stored region. - bli_randm( &a ); - bli_mktrim( &a ); + bli_randm(&a); + bli_mktrim(&a); // Load the diagonal of A to make it more likely to be invertible. - bli_shiftd( &BLIS_TWO, &a ); + bli_shiftd(&BLIS_TWO, &a); - bli_obj_create( dt, 1, 1, 0, 0, &alpha ); - bli_setsc( (2.0/1.0), 1.0, &alpha ); + bli_obj_create(dt, 1, 1, 0, 0, &alpha); + bli_setsc((2.0 / 1.0), 1.0, &alpha); + bli_copym(&c, &c_save); - bli_copym( &c, &c_save ); - dtime_save = DBL_MAX; - for ( r = 0; r < n_repeats; ++r ) + for (r = 0; r < n_repeats; ++r) { - bli_copym( &c_save, &c ); - + bli_copym(&c_save, &c); dtime = bli_clock(); - #ifdef PRINT - bli_invertd( &a ); - bli_printm( "a", &a, "%4.1f", "" ); - bli_invertd( &a ); - bli_printm( "c", &c, "%4.1f", "" ); + bli_invertd(&a); + bli_printm("a", &a, "%4.1f", ""); + bli_invertd(&a); + bli_printm("c", &c, "%4.1f", ""); #endif #ifdef BLIS - bli_trsm( side, - &alpha, - &a, - &c ); + bli_trsm(side, + &alpha, + &a, + &c); #else #ifdef CBLAS - enum CBLAS_ORDER cblas_order; - enum CBLAS_TRANSPOSE cblas_transa; - enum CBLAS_UPLO cblas_uplo; - enum CBLAS_SIDE cblas_side; - enum CBLAS_DIAG cblas_diag; - - if ( bli_obj_row_stride( &c ) == 1 ) - cblas_order = CblasColMajor; - else - cblas_order = CblasRowMajor; - - if( bli_is_trans( transa ) ) - cblas_transa = CblasTrans; - else if( bli_is_conjtrans( transa ) ) - cblas_transa = CblasConjTrans; - else - cblas_transa = CblasNoTrans; - - if(bli_is_upper(uploa)) - cblas_uplo = CblasUpper; - else - cblas_uplo = CblasLower; - - if(bli_is_left(side)) - cblas_side = CblasLeft; - else - cblas_side = CblasRight; - - if(bli_is_unit_diag(diaga)) - cblas_diag = CblasUnit; - else - cblas_diag = CblasNonUnit; + enum CBLAS_ORDER cblas_order; + enum CBLAS_TRANSPOSE cblas_transa; + enum CBLAS_UPLO cblas_uplo; + enum CBLAS_SIDE cblas_side; + enum CBLAS_DIAG cblas_diag; + + if (bli_obj_row_stride(&c) == 1) + cblas_order = CblasColMajor; + else + cblas_order = CblasRowMajor; + + if (bli_is_trans(transa)) + cblas_transa = CblasTrans; + else if (bli_is_conjtrans(transa)) + cblas_transa = CblasConjTrans; + else + cblas_transa = CblasNoTrans; + + if (bli_is_upper(uploa)) + cblas_uplo = CblasUpper; + else + cblas_uplo = CblasLower; + + if (bli_is_left(side)) + cblas_side = CblasLeft; + else + cblas_side = CblasRight; + + if (bli_is_unit_diag(diaga)) + cblas_diag = CblasUnit; + else + cblas_diag = CblasNonUnit; #else - f77_char f77_transa; - bli_param_map_blis_to_netlib_trans( transa, &f77_transa ); + f77_char f77_transa; + bli_param_map_blis_to_netlib_trans(transa, &f77_transa); #endif - if ( bli_is_float( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); + if (bli_is_float(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); - float* alphap = bli_obj_buffer( &alpha ); - float* ap = bli_obj_buffer( &a ); - float* cp = bli_obj_buffer( &c ); + float *alphap = bli_obj_buffer(&alpha); + float *ap = bli_obj_buffer(&a); + float *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_strsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - *alphap, - ap, lda, - cp, ldc - ); + cblas_strsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + *alphap, + ap, lda, + cp, ldc); #else - strsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + strsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - } - else if ( bli_is_double( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - double* alphap = bli_obj_buffer( &alpha ); - double* ap = bli_obj_buffer( &a ); - double* cp = bli_obj_buffer( &c ); + } + else if (bli_is_double(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + double *alphap = bli_obj_buffer(&alpha); + double *ap = bli_obj_buffer(&a); + double *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_dtrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - *alphap, - ap, lda, - cp, ldc - ); -#else - dtrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + cblas_dtrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + *alphap, + ap, lda, + cp, ldc); +#else + dtrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - - } - else if ( bli_is_scomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - scomplex* alphap = bli_obj_buffer( &alpha ); - scomplex* ap = bli_obj_buffer( &a ); - scomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_scomplex(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + scomplex *alphap = bli_obj_buffer(&alpha); + scomplex *ap = bli_obj_buffer(&a); + scomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_ctrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - alphap, - ap, lda, - cp, ldc - ); + cblas_ctrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + alphap, + ap, lda, + cp, ldc); #else - ctrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + ctrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - } - else if ( bli_is_dcomplex( dt ) ) - { - f77_int mm = bli_obj_length( &c ); - f77_int nn = bli_obj_width( &c ); - f77_int lda = bli_obj_col_stride( &a ); - f77_int ldc = bli_obj_col_stride( &c ); - dcomplex* alphap = bli_obj_buffer( &alpha ); - dcomplex* ap = bli_obj_buffer( &a ); - dcomplex* cp = bli_obj_buffer( &c ); + } + else if (bli_is_dcomplex(dt)) + { + f77_int mm = bli_obj_length(&c); + f77_int nn = bli_obj_width(&c); + f77_int lda = bli_obj_col_stride(&a); + f77_int ldc = bli_obj_col_stride(&c); + dcomplex *alphap = bli_obj_buffer(&alpha); + dcomplex *ap = bli_obj_buffer(&a); + dcomplex *cp = bli_obj_buffer(&c); #ifdef CBLAS - cblas_ztrsm( cblas_order, - cblas_side, - cblas_uplo, - cblas_transa, - cblas_diag, - mm, - nn, - alphap, - ap, lda, - cp, ldc - ); + cblas_ztrsm(cblas_order, + cblas_side, + cblas_uplo, + cblas_transa, + cblas_diag, + mm, + nn, + alphap, + ap, lda, + cp, ldc); #else - ztrsm_( &f77_side, - &f77_uploa, - &f77_transa, - &f77_diaga, - &mm, - &nn, - alphap, - ap, &lda, - cp, &ldc ); + ztrsm_(&f77_side, + &f77_uploa, + &f77_transa, + &f77_diaga, + &mm, + &nn, + alphap, + ap, &lda, + cp, &ldc); #endif - }else{ - printf("Invalid data type! Exiting!\n"); - exit(1); - } + } + else + { + printf("Invalid data type! Exiting!\n"); + exit(1); + } #endif - dtime_save = bli_clock_min_diff( dtime_save, dtime ); + dtime_save = bli_clock_min_diff(dtime_save, dtime); } - if ( bli_is_left( side ) ) - gflops = ( 1.0 * m * m * n ) / ( dtime_save * 1.0e9 ); + if (bli_is_left(side)) + gflops = (1.0 * m * m * n) / (dtime_save * 1.0e9); else - gflops = ( 1.0 * m * n * n ) / ( dtime_save * 1.0e9 ); + gflops = (1.0 * m * n * n) / (dtime_save * 1.0e9); - if ( bli_is_complex( dt ) ) gflops *= 4.0; + if (bli_is_complex(dt)) + gflops *= 4.0; #ifdef BLIS - printf( "data_trsm_blis" ); + printf("data_trsm_blis"); #else - printf( "data_trsm_%s", BLAS ); + printf("data_trsm_%s", BLAS); #endif #ifdef FILE_IN_OUT #ifdef READ_ALL_PARAMS_FROM_FILE - printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n",side_c, uploa_c, transa_c, diaga_c, - (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + printf("%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); - fprintf(fout,"%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, - (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + fprintf(fout, "%c\t %c\t %c\t %c\t %4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", side_c, uploa_c, transa_c, diaga_c, + (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); #else - printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); - fprintf(fout,"%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long )m, (unsigned long ) n, - (unsigned long )cs_a, (unsigned long )cs_b, - gflops); + printf("%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); + fprintf(fout, "%4lu\t %4lu\t %4lu\t %4lu\t %6.3f\n", (unsigned long)m, (unsigned long)n, + (unsigned long)cs_a, (unsigned long)cs_b, + gflops); #endif -fflush(fout); + fflush(fout); #else - printf( "( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", - ( unsigned long )(p - p_begin)/p_inc + 1, - ( unsigned long )m, - ( unsigned long )n, gflops ); + printf("( %2lu, 1:3 ) = [ %4lu %4lu %7.2f ];\n", + (unsigned long)(p - p_begin) / p_inc + 1, + (unsigned long)m, + (unsigned long)n, gflops); #endif - bli_obj_free( &alpha ); + bli_obj_free(&alpha); - bli_obj_free( &a ); - bli_obj_free( &c ); - bli_obj_free( &c_save ); + bli_obj_free(&a); + bli_obj_free(&c); + bli_obj_free(&c_save); } #ifdef FILE_IN_OUT - fclose(fin); - fclose(fout); + fclose(fin); + fclose(fout); #endif - //bli_finalize(); + // bli_finalize(); return 0; } - From 0c211363c4886e6f0a54859591ebbd58fe15b6c3 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 11 Mar 2022 00:12:52 -0600 Subject: [PATCH 43/63] Implemented optimal dotxv kernel Details: - Intrinsic implementation of zdotxv, cdotxv kernel - Unrolling in multiple of 8, remaining corner cases are handled serially for zdotxv kernel - Unrolling in multiple of 16, remainig corner cases are handled serially for cdotxv kernel - Added declaration in zen contexts AMD-Internal: [CPUPL-2050] Change-Id: Id58b0dbfdb7a782eb50eecc7142f051b630d9211 --- config/zen/bli_cntx_init_zen.c | 4 +- config/zen2/bli_cntx_init_zen2.c | 4 +- config/zen3/bli_cntx_init_zen3.c | 4 +- kernels/zen/1/bli_dotxv_zen_int.c | 499 ++++++++++++++++++++++++++++++ kernels/zen/bli_kernels_zen.h | 4 +- 5 files changed, 511 insertions(+), 4 deletions(-) diff --git a/config/zen/bli_cntx_init_zen.c b/config/zen/bli_cntx_init_zen.c index 674549d77f..3fea3ea8f9 100644 --- a/config/zen/bli_cntx_init_zen.c +++ b/config/zen/bli_cntx_init_zen.c @@ -103,7 +103,7 @@ void bli_cntx_init_zen( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -135,6 +135,8 @@ void bli_cntx_init_zen( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv #if 0 BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int, diff --git a/config/zen2/bli_cntx_init_zen2.c b/config/zen2/bli_cntx_init_zen2.c index 48cb90a4f8..1ecb62ff52 100644 --- a/config/zen2/bli_cntx_init_zen2.c +++ b/config/zen2/bli_cntx_init_zen2.c @@ -115,7 +115,7 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -142,6 +142,8 @@ void bli_cntx_init_zen2( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, diff --git a/config/zen3/bli_cntx_init_zen3.c b/config/zen3/bli_cntx_init_zen3.c index e83a12b401..02e264d277 100644 --- a/config/zen3/bli_cntx_init_zen3.c +++ b/config/zen3/bli_cntx_init_zen3.c @@ -115,7 +115,7 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // Update the context with optimized level-1v kernels. bli_cntx_set_l1v_kers ( - 24, + 26, #if 1 // amaxv BLIS_AMAXV_KER, BLIS_FLOAT, bli_samaxv_zen_int, @@ -142,6 +142,8 @@ void bli_cntx_init_zen3( cntx_t* cntx ) // dotxv BLIS_DOTXV_KER, BLIS_FLOAT, bli_sdotxv_zen_int, BLIS_DOTXV_KER, BLIS_DOUBLE, bli_ddotxv_zen_int, + BLIS_DOTXV_KER, BLIS_DCOMPLEX, bli_zdotxv_zen_int, + BLIS_DOTXV_KER, BLIS_SCOMPLEX, bli_cdotxv_zen_int, // scalv BLIS_SCALV_KER, BLIS_FLOAT, bli_sscalv_zen_int10, diff --git a/kernels/zen/1/bli_dotxv_zen_int.c b/kernels/zen/1/bli_dotxv_zen_int.c index 8ba1d1bba4..c210eceff5 100644 --- a/kernels/zen/1/bli_dotxv_zen_int.c +++ b/kernels/zen/1/bli_dotxv_zen_int.c @@ -332,3 +332,502 @@ void bli_ddotxv_zen_int PASTEMAC(d,axpys)( *alpha, rho0, *rho ); } + + +void bli_zdotxv_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + dcomplex* restrict alpha, + dcomplex* restrict x, inc_t incx, + dcomplex* restrict y, inc_t incy, + dcomplex* restrict beta, + dcomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 2; + const dim_t n_iter_unroll = 4; + + dim_t i; + dim_t n_viter; + dim_t n_left; + + dcomplex* restrict x0; + dcomplex* restrict y0; + dcomplex rho0; + + v4df_t rhov[8], xv[4], yv[8]; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + { + bli_toggle_conj( &conjx_use ); + } + // If beta is zero, initialize rho1 to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(z,eq0)( *beta ) ) + { + PASTEMAC(z,set0s)( *rho ); + } + else + { + PASTEMAC(z,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. + if ( bli_zero_dim1( n ) || PASTEMAC(z,eq0)( *alpha ) ) return; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + + // If there is anything that would interfere with our use of contiguous + // vector loads/stores, override n_viter and n_left to use scalar code + // for all iterations. + if ( incx != 1 || incy != 1 ) + { + n_viter = 0; + n_left = n; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm256_setzero_pd(); + rhov[1].v = _mm256_setzero_pd(); + rhov[2].v = _mm256_setzero_pd(); + rhov[3].v = _mm256_setzero_pd(); + + rhov[4].v = _mm256_setzero_pd(); + rhov[5].v = _mm256_setzero_pd(); + rhov[6].v = _mm256_setzero_pd(); + rhov[7].v = _mm256_setzero_pd(); + + if ( bli_is_conj( conjx_use ) ) + { + __m256d conju = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + yv[0].v = _mm256_mul_pd(yv[0].v, conju); + yv[1].v = _mm256_mul_pd(yv[1].v, conju); + yv[2].v = _mm256_mul_pd(yv[2].v, conju); + yv[3].v = _mm256_mul_pd(yv[3].v, conju); + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_pd( yv[0].v, 15 ); + yv[5].v = _mm256_permute_pd( yv[1].v, 15 ); + yv[6].v = _mm256_permute_pd( yv[2].v, 15 ); + yv[7].v = _mm256_permute_pd( yv[3].v, 15 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_pd( yv[0].v, 0 ); + yv[1].v = _mm256_permute_pd( yv[1].v, 0 ); + yv[2].v = _mm256_permute_pd( yv[2].v, 0 ); + yv[3].v = _mm256_permute_pd( yv[3].v, 0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + else + { + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_pd((double *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_pd((double *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_pd((double *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_pd((double *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_pd((double *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_pd((double *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_pd((double *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_pd((double *) (y0 + 3*n_elem_per_reg) ); + + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //--------------- + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_pd( yv[0].v, 15 ); + yv[5].v = _mm256_permute_pd( yv[1].v, 15 ); + yv[6].v = _mm256_permute_pd( yv[2].v, 15 ); + yv[7].v = _mm256_permute_pd( yv[3].v, 15 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //---------------- + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_pd( yv[0].v, 0 ); + yv[1].v = _mm256_permute_pd( yv[1].v, 0 ); + yv[2].v = _mm256_permute_pd( yv[2].v, 0 ); + yv[3].v = _mm256_permute_pd( yv[3].v, 0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_pd( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_pd( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_pd( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_pd( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_pd( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_pd( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_pd( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_pd( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + // - + - + + //yi0*xi0 yi0*xr0 yi1*xi1 yi1*xr1 + rhov[4].v = _mm256_permute_pd(rhov[4].v, 0x05); + rhov[5].v = _mm256_permute_pd(rhov[5].v, 0x05); + rhov[6].v = _mm256_permute_pd(rhov[6].v, 0x05); + rhov[7].v = _mm256_permute_pd(rhov[7].v, 0x05); + + rhov[0].v = _mm256_addsub_pd(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_addsub_pd(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_addsub_pd(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_addsub_pd(rhov[3].v, rhov[7].v); + + // Accumulate the unrolled rho vectors into a single vector. + rhov[0].v = _mm256_add_pd(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_pd(rhov[3].v,rhov[0].v); + + v2df_t inter1, inter2; + + inter1.v = _mm256_extractf128_pd(rhov[0].v,1); + inter2.v = _mm256_extractf128_pd(rhov[0].v,0); + + inter1.v = _mm_add_pd(inter1.v, inter2.v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter1.d[0]; + rho0.imag = inter1.d[1]; + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjx_use)) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // If there are leftover iterations, perform them with scalar code. + if ( bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(z,dotjs)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + else + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(z,dots)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + + if ( bli_is_conj( conjy ) ) + PASTEMAC(z,conjs)( rho0 ); + + // Accumulate the final result into the output variable. + PASTEMAC(z,axpys)( *alpha, rho0, *rho ); +} + +void bli_cdotxv_zen_int + ( + conj_t conjx, + conj_t conjy, + dim_t n, + scomplex* restrict alpha, + scomplex* restrict x, inc_t incx, + scomplex* restrict y, inc_t incy, + scomplex* restrict beta, + scomplex* restrict rho, + cntx_t* restrict cntx + ) +{ + const dim_t n_elem_per_reg = 4; + const dim_t n_iter_unroll = 4; + + dim_t i; + dim_t n_viter; + dim_t n_left; + + scomplex* restrict x0; + scomplex* restrict y0; + scomplex rho0; + + v8sf_t rhov[8], xv[4], yv[8]; + + conj_t conjx_use = conjx; + if ( bli_is_conj( conjy ) ) + { + bli_toggle_conj( &conjx_use ); + } + // If beta is zero, initialize rho1 to zero instead of scaling + // rho by beta (in case rho contains NaN or Inf). + if ( PASTEMAC(c,eq0)( *beta ) ) + { + PASTEMAC(c,set0s)( *rho ); + } + else + { + PASTEMAC(c,scals)( *beta, *rho ); + } + + // If the vector dimension is zero, output rho and return early. + if ( bli_zero_dim1( n ) || PASTEMAC(c,eq0)( *alpha ) ) return; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + n_viter = ( n ) / ( n_elem_per_reg * n_iter_unroll ); + n_left = ( n ) % ( n_elem_per_reg * n_iter_unroll ); + + // If there is anything that would interfere with our use of contiguous + // vector loads/stores, override n_viter and n_left to use scalar code + // for all iterations. + if ( incx != 1 || incy != 1 ) + { + n_viter = 0; + n_left = n; + } + + // Initialize local pointers. + x0 = x; + y0 = y; + + // Initialize the unrolled iterations' rho vectors to zero. + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + + rhov[4].v = _mm256_setzero_ps(); + rhov[5].v = _mm256_setzero_ps(); + rhov[6].v = _mm256_setzero_ps(); + rhov[7].v = _mm256_setzero_ps(); + + if ( bli_is_conj( conjx_use ) ) + { + __m256 conju = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_ps((float *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_ps((float *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_ps((float *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_ps((float *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_ps((float *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_ps((float *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_ps((float *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_ps((float *) (y0 + 3*n_elem_per_reg) ); + + yv[0].v = _mm256_mul_ps(yv[0].v, conju); + yv[1].v = _mm256_mul_ps(yv[1].v, conju); + yv[2].v = _mm256_mul_ps(yv[2].v, conju); + yv[3].v = _mm256_mul_ps(yv[3].v, conju); + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_ps( yv[0].v, 0xf5 ); + yv[5].v = _mm256_permute_ps( yv[1].v, 0xf5 ); + yv[6].v = _mm256_permute_ps( yv[2].v, 0xf5 ); + yv[7].v = _mm256_permute_ps( yv[3].v, 0xf5 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //after permute of vector registers + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_ps( yv[0].v, 0xa0 ); + yv[1].v = _mm256_permute_ps( yv[1].v, 0xa0 ); + yv[2].v = _mm256_permute_ps( yv[2].v, 0xa0 ); + yv[3].v = _mm256_permute_ps( yv[3].v, 0xa0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_ps( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + else + { + for ( i = 0; i < n_viter; ++i ) + { + // Load the x and y input vector elements. + xv[0].v = _mm256_loadu_ps((float *) (x0 + 0*n_elem_per_reg) ); + yv[0].v = _mm256_loadu_ps((float *) (y0 + 0*n_elem_per_reg) ); + + xv[1].v = _mm256_loadu_ps((float *) (x0 + 1*n_elem_per_reg) ); + yv[1].v = _mm256_loadu_ps((float *) (y0 + 1*n_elem_per_reg) ); + + xv[2].v = _mm256_loadu_ps((float *) (x0 + 2*n_elem_per_reg) ); + yv[2].v = _mm256_loadu_ps((float *) (y0 + 2*n_elem_per_reg) ); + + xv[3].v = _mm256_loadu_ps((float *) (x0 + 3*n_elem_per_reg) ); + yv[3].v = _mm256_loadu_ps((float *) (y0 + 3*n_elem_per_reg) ); + + //yi0 yi0 yi1 yi1 + //xr0 xi0 xr1 xi1 + //--------------- + //yi0*xr0 yi0*xi0 yi1*xr1 yi1*xi1 + yv[4].v = _mm256_permute_ps( yv[0].v, 0xf5 ); + yv[5].v = _mm256_permute_ps( yv[1].v, 0xf5 ); + yv[6].v = _mm256_permute_ps( yv[2].v, 0xf5 ); + yv[7].v = _mm256_permute_ps( yv[3].v, 0xf5 ); + + //yr0 yr0 yr1 yr1 + //xr0 xi0 xr1 xi1 + //---------------- + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + yv[0].v = _mm256_permute_ps( yv[0].v, 0xa0 ); + yv[1].v = _mm256_permute_ps( yv[1].v, 0xa0 ); + yv[2].v = _mm256_permute_ps( yv[2].v, 0xa0 ); + yv[3].v = _mm256_permute_ps( yv[3].v, 0xa0 ); + + // Compute the element-wise product of the x and y vectors, + // storing in the corresponding rho vectors. + rhov[0].v = _mm256_fmadd_ps( xv[0].v, yv[0].v, rhov[0].v ); + rhov[1].v = _mm256_fmadd_ps( xv[1].v, yv[1].v, rhov[1].v ); + rhov[2].v = _mm256_fmadd_ps( xv[2].v, yv[2].v, rhov[2].v ); + rhov[3].v = _mm256_fmadd_ps( xv[3].v, yv[3].v, rhov[3].v ); + + rhov[4].v = _mm256_fmadd_ps( xv[0].v, yv[4].v, rhov[4].v ); + rhov[5].v = _mm256_fmadd_ps( xv[1].v, yv[5].v, rhov[5].v ); + rhov[6].v = _mm256_fmadd_ps( xv[2].v, yv[6].v, rhov[6].v ); + rhov[7].v = _mm256_fmadd_ps( xv[3].v, yv[7].v, rhov[7].v ); + + x0 += ( n_elem_per_reg * n_iter_unroll ); + y0 += ( n_elem_per_reg * n_iter_unroll ); + } + } + + //yr0*xr0 yr0*xi0 yr1*xr1 yr1*xi1 + // - + - + + //yi0*xi0 yi0*xr0 yi1*xi1 yi1*xr1 + rhov[4].v = _mm256_permute_ps(rhov[4].v, 0xb1); + rhov[5].v = _mm256_permute_ps(rhov[5].v, 0xb1); + rhov[6].v = _mm256_permute_ps(rhov[6].v, 0xb1); + rhov[7].v = _mm256_permute_ps(rhov[7].v, 0xb1); + + rhov[0].v = _mm256_addsub_ps(rhov[0].v, rhov[4].v); + rhov[1].v = _mm256_addsub_ps(rhov[1].v, rhov[5].v); + rhov[2].v = _mm256_addsub_ps(rhov[2].v, rhov[6].v); + rhov[3].v = _mm256_addsub_ps(rhov[3].v, rhov[7].v); + + // Accumulate the unrolled rho vectors into a single vector. + rhov[0].v = _mm256_add_ps(rhov[1].v,rhov[0].v); + rhov[0].v = _mm256_add_ps(rhov[2].v,rhov[0].v); + rhov[0].v = _mm256_add_ps(rhov[3].v,rhov[0].v); + + v4sf_t inter1, inter2; + + inter1.v = _mm256_extractf128_ps(rhov[0].v,1); + inter2.v = _mm256_extractf128_ps(rhov[0].v,0); + + inter1.v = _mm_add_ps(inter1.v, inter2.v); + + // Accumulate the final rho vector into a single scalar result. + rho0.real = inter1.f[0] + inter1.f[2]; + rho0.imag = inter1.f[1] + inter1.f[3]; + + /* Negate sign of imaginary value when vector y is conjugate */ + if ( bli_is_conj(conjx_use)) + rho0.imag = -rho0.imag; + + // Issue vzeroupper instruction to clear upper lanes of ymm registers. + // This avoids a performance penalty caused by false dependencies when + // transitioning from from AVX to SSE instructions (which may occur + // as soon as the n_left cleanup loop below if BLIS is compiled with + // -mfpmath=sse). + _mm256_zeroupper(); + + // If there are leftover iterations, perform them with scalar code. + if ( bli_is_conj( conjx_use ) ) + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(c,dotjs)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + else + { + for ( i = 0; i < n_left; ++i ) + { + PASTEMAC(c,dots)( *x0, *y0, rho0 ); + x0 += incx; + y0 += incy; + } + } + + if ( bli_is_conj( conjy ) ) + PASTEMAC(c,conjs)( rho0 ); + + // Accumulate the final result into the output variable. + PASTEMAC(c,axpys)( *alpha, rho0, *rho ); +} diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 77d34807af..236273d82c 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020 - 21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2020 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -78,6 +78,8 @@ DOTV_KER_PROT( dcomplex, z, dotv_zen_int5 ) // dotxv (intrinsics) DOTXV_KER_PROT( float, s, dotxv_zen_int ) DOTXV_KER_PROT( double, d, dotxv_zen_int ) +DOTXV_KER_PROT( dcomplex, z, dotxv_zen_int ) +DOTXV_KER_PROT( scomplex, c, dotxv_zen_int ) // scalv (intrinsics) SCALV_KER_PROT( float, s, scalv_zen_int ) From 25d28f85a93c9a24727c6204f9ad8646d81183fa Mon Sep 17 00:00:00 2001 From: Chandrashekara K R Date: Wed, 13 Apr 2022 10:03:27 +0530 Subject: [PATCH 44/63] Added the checks to not defining the bool type for C++ code in windows to avoid redefinition build time errror. AMD-Internal: [CPUPL-2037] Change-Id: I065da9206ab06f60876324f258ee12fb9fe83f88 --- frame/include/bli_type_defs.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h index 1a3dea1d3d..9d45aec1ab 100644 --- a/frame/include/bli_type_defs.h +++ b/frame/include/bli_type_defs.h @@ -89,10 +89,14 @@ typedef unsigned long int guint_t; // -- Boolean type -- // NOTE: bool_t is no longer used and has been replaced with C99's bool type. +// Not defining the bool type for C++ code in windows platform to avoid +// duplicate definition build error. #ifdef _WIN32 +#ifndef __cplusplus #undef bool typedef gint_t bool; #endif +#endif // BLIS uses TRUE and FALSE macro constants as possible boolean values, but we // define these macros in terms of true and false, respectively, which are // defined by C99 in stdbool.h. From c22bea0c170e6fc4e781e330c07b901f816dfc4f Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 8 Apr 2022 13:19:34 +0530 Subject: [PATCH 45/63] Parallelization of dtrsm_small routine 1. Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right). 2. Fine-tuning with AOCL_DYNAMIC to achieve better performance. AMD-Internal: [CPUPL-2103] Change-Id: I6be6a2b579de7df9a3141e0d68bdf3e8a869a005 --- frame/base/bli_rntm.c | 15 +++- frame/compat/bla_trsm_amd.c | 41 ++++++++- kernels/zen/3/bli_trsm_small.c | 147 ++++++++++++++++++++++++++++++--- kernels/zen/bli_kernels_zen.h | 14 +++- 4 files changed, 201 insertions(+), 16 deletions(-) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index f8e00c6208..c15650e918 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -631,13 +631,22 @@ void bli_nthreads_optimum( else n_threads_ideal = n_threads; } - else if( family == BLIS_TRSM && bli_obj_is_double(c)) + else if( family == BLIS_TRSM && bli_obj_is_double(c) ) { dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); - if(m<=512 && n<=512) - n_threads_ideal = 4; +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; +#else + if ( (m <= 512) && (n <= 512) ) + n_threads_ideal = 4; +#endif } else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) { diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index e1a2fffafd..3b3850928a 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -395,7 +395,7 @@ void strsm_ ) { AOCL_DTL_TRACE_ENTRY(AOCL_DTL_LEVEL_INFO) - AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 'd', + AOCL_DTL_LOG_TRSM_INPUTS(AOCL_DTL_LEVEL_TRACE_1, 's', *side, *uploa,*transa, *diaga, *m, *n, (void*)alpha,*lda, *ldb); @@ -886,8 +886,45 @@ void dtrsm_ return; } } -#endif + + //bli_trsm_small_mt is performing better than native multithread + //for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + { + err_t status; + status = bli_trsm_small_mt + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } +#endif// BLIS_ENABLE_OPENMP +#endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM } + bli_trsmnat ( blis_side, diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 07077010f2..f8c0ea5911 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -3821,15 +3821,22 @@ err_t bli_trsm_small num_t dt = bli_obj_dt(a); switch(dt) { - case BLIS_DOUBLE: - case BLIS_FLOAT: - case BLIS_SCOMPLEX: - { - if(m > 1000 || n > 1000) { + case BLIS_DOUBLE: + { + bool nt = bli_thread_get_is_parallel(); + if((nt == 0) && (m > 1000 || n > 1000)) { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + case BLIS_FLOAT: + case BLIS_SCOMPLEX: + { + if(m > 1000 || n > 1000) { return BLIS_NOT_YET_IMPLEMENTED; } break; - } + } case BLIS_DCOMPLEX: { if(m > 500 || n > 500) { @@ -3886,6 +3893,126 @@ err_t bli_trsm_small return err; }; +#ifdef BLIS_ENABLE_OPENMP +/* + * Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right) + */ + +err_t bli_trsm_small_mt +( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + rntm_t rntm; + gint_t m = bli_obj_length( b ); // number of rows of matrix b + gint_t n = bli_obj_width( b ); // number of columns of Matrix b + dim_t d_mr = 8,d_nr = 6; + + num_t dt = bli_obj_dt(a); + switch(dt) + { + case BLIS_DOUBLE: + { + d_mr = 8,d_nr = 6; + break; + } + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } + + #ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if( bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); + } + #endif + + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + + if (n_threads < 0 ) n_threads = 1; + + err_t status = BLIS_SUCCESS; + _Pragma( "omp parallel num_threads(n_threads)" ) + { + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + + obj_t b_t; + dim_t start; // Each thread start Index + dim_t end; // Each thread end Index + thrinfo_t thread; + + thread.n_way = n_threads; + thread.work_id = tid; + thread.ocomm_id = tid; + + + // Compute start and end indexes of matrix partitioning for each thread + if ( bli_is_right( side ) ) + { + bli_thread_range_sub ( &thread, + m, + d_mr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + else + { + bli_thread_range_sub ( &thread, + n, + d_nr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end-start, b, &b_t); + } + + // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to + // all threads + err_t status_l = BLIS_SUCCESS; + + status_l = bli_trsm_small + ( + side, + alpha, + a, + &b_t, + NULL, + NULL + ); + // To capture the error populated from any of the threads + _Pragma( "omp critical" ) + status = (status != BLIS_NOT_YET_IMPLEMENTED)?status_l:status; + } + + return status; +}// End of function +#endif + /* * ZTRSM utilities and kernel functions */ @@ -6105,7 +6232,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -8565,7 +8692,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks @@ -10909,7 +11036,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B //pointers that point to blocks for GEMM and TRSM double *a10, *a11, *b01, *b11; @@ -12889,7 +13016,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B + double *B = bli_obj_buffer_at_off(b); //pointer to matrix B double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 236273d82c..904e6cfbbf 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -309,7 +309,7 @@ void bli_dgemm_ref_k1_nn double* c, const inc_t ldc ); - err_t bli_trsm_small +err_t bli_trsm_small ( side_t side, obj_t* alpha, @@ -319,6 +319,18 @@ void bli_dgemm_ref_k1_nn cntl_t* cntl ); +#ifdef BLIS_ENABLE_OPENMP +err_t bli_trsm_small_mt + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); +#endif + // threshold functions bool bli_cntx_gemmtsup_thresh_is_met_zen ( From ade8525c3aae35e969756f0bda48f8f2ddbce27f Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 7 Apr 2022 13:47:39 +0530 Subject: [PATCH 46/63] Added runtime control for DTL logging Feature The logs can be enabled with following two methods: -- Environment variable based control: The feature can be enabled by specifying environment variable AOCL_VERBOSE=1. -- API based control: Two API's will be added to enable/disable logging at runtime 1. AOCL_DTL_Enable_Logs() 2. AOCL_DTL_Disable_Logs() -- The API takes precedence over the environment settings. AMD-Internal: [CPUPL-2101] Change-Id: Ie71c1095496fae89226049c9b9f80b00400350d5 --- aocl_dtl/aocldtl.c | 51 ++++++++++++---- aocl_dtl/aocldtl.h | 25 ++++++++ aocl_dtl/aocldtl_blis.h | 129 +++++++++++++++++++++++++--------------- aocl_dtl/aocldtlcf.h | 20 +++++-- 4 files changed, 163 insertions(+), 62 deletions(-) diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index 6f24788aa0..f3c1658ff8 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -5,7 +5,7 @@ * These functions are invoked though macros by * end user. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *=======================================================================*/ #include "blis.h" @@ -56,6 +56,10 @@ static char *pchDTL_LOG_FILE = AOCL_DTL_LOG_FILE; /* Global file pointer for logging the results */ AOCL_FLIST_Node *gpLogFileList = NULL; + + +/* Global flag to check if logging is enabled or not */ +Bool gbIsLoggingEnabled = FALSE; #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -82,6 +86,23 @@ AOCL_FLIST_Node *gpAutoTraceFileList = NULL; void DTL_Initialize( uint32 ui32CurrentLogLevel) { + /* + * This function can be invoked multiple times either via library + * initialization function (e.g. bli_init()) or when user changes + * logging state using API. However we want it to run only once + * This flag ensure that it is executed only once. + * + * DTL can be used with many libraries hence it needs its own + * method to ensure this. + */ + + static Bool bIsDTLInitDone = FALSE; + + if (bIsDTLInitDone) + { + return; + } + /* If user selects invalid trace log level then the dafault trace log level will be AOCL_DTL_LEVEL_ALL */ if ((ui32CurrentLogLevel < 1) || (ui32CurrentLogLevel > AOCL_DTL_LEVEL_ALL)) @@ -107,15 +128,9 @@ void DTL_Initialize( #endif #if (AOCL_DTL_LOG_ENABLE || AOCL_DTL_DUMP_ENABLE) - /* Create/Open the file to log the log data */ - AOCL_FLIST_AddFile(pchDTL_LOG_FILE, &gpLogFileList, AOCL_gettid()); - - if (NULL == gpLogFileList) - { - /* Unable to open the specified file.*/ - AOCL_DEBUGPRINT("Unable to create the log file %s\n", pchDTL_LOG_FILE); - return; - } + + /* Check if DTL logging is requested via envoronment variable */ + gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", FALSE ); #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -133,6 +148,9 @@ void DTL_Initialize( /* Save Id for main thread */ gtidMainThreadID = AOCL_gettid(); + // Ensure that this function is executed only once + bIsDTLInitDone = TRUE; + } /* DTL_Initialize */ #endif @@ -193,6 +211,19 @@ void DTL_Trace( { uint8 i = 0; AOCL_FAL_FILE *pOutFile = NULL; + +#if AOCL_DTL_LOG_ENABLE + /* + * For performance reasons we check the logging state in end user + * macros, this is just an additional check in case the function + * is invoked from any other context. + */ + if (gbIsLoggingEnabled == FALSE && ui8LogType == TRACE_TYPE_LOG) + { + return; + } +#endif + uint64 u64EventTime = AOCL_getTimestamp(); dim_t u64RequestedThreadsCount = AOCL_get_requested_threads_count(); diff --git a/aocl_dtl/aocldtl.h b/aocl_dtl/aocldtl.h index 7ce81561b7..f520518e9c 100644 --- a/aocl_dtl/aocldtl.h +++ b/aocl_dtl/aocldtl.h @@ -109,6 +109,31 @@ void AOCL_DTL_start_perf_timer(void); uint64 AOCL_DTL_get_time_spent(void); +/* + * Logging of inputs can be enabled by two methods: + * + * 1. Using environment variable AOCL_VERBOSE. + * 2. APIs + * + * The API takes precedence over environment variable. + * + * The global flag is maintain in the code to track the final + * state of the logging feature. + */ +extern Bool gbIsLoggingEnabled; + +/* API to enable logging at runtime */ +#define AOCL_DTL_Enable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = TRUE; + +/* API to disable logging at runtime */ +#define AOCL_DTL_Disable_Logs() \ + /* Initialize DTL if not alredy done so */ \ + AOCL_DTL_INITIALIZE(AOCL_DTL_TRACE_LEVEL); \ + gbIsLoggingEnabled = FALSE; + /* Macro to log the Data */ #define AOCL_DTL_START_PERF_TIMER() \ AOCL_DTL_start_perf_timer() diff --git a/aocl_dtl/aocldtl_blis.h b/aocl_dtl/aocldtl_blis.h index a9ea3368f9..7b352f9d43 100755 --- a/aocl_dtl/aocldtl_blis.h +++ b/aocl_dtl/aocldtl_blis.h @@ -3,7 +3,7 @@ * * Description : BLIS library specific debug helpes. * - * Copyright (C) 2020-2021, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -385,115 +385,148 @@ void AOCL_DTL_log_trmm_sizes(int8 loglevel, #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_sizes(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GEMM_STATS(loglevel, m, n, k) \ - AOCL_DTL_log_gemm_stats(loglevel, m, n, k); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemm_stats(loglevel, m, n, k); #define AOCL_DTL_LOG_TRSM_INPUTS(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ - AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsm_sizes(loglevel, dt, side, uploa, transa, diaga, m, n, alpha, lda, ldb, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GEMMT_INPUTS(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemmt_sizes(loglevel, dt, uplo, transa, transb, n, k, alpha, lda, ldb, beta, ldc, \ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_HEMM_INPUTS(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_hemm_sizes(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_hemm_sizes(loglevel, dt_type, side, uplo, m, n, alpha, lda, ldb, beta, ldc, \ + __FILE__, __FUNCTION__, __LINE__); // Level-3 Macros #define AOCL_DTL_LOG_HERK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc)\ - AOCL_DTL_log_herk_sizes(loglevel, dt_type, transa, uploc, m, k, alpha, lda, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_herk_sizes(loglevel, dt_type, transa, uploc, m, k, alpha, lda, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER2K_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc)\ - AOCL_DTL_log_her2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_SYMM_INPUTS(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc)\ - AOCL_DTL_log_symm_sizes(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_symm_sizes(loglevel, dt_type, side, uploa, m, n, alpha, lda, ldb, beta, ldc, __FILE__,\ + __FUNCTION__, __LINE__); // Level-2 Macros #define AOCL_DTL_LOG_GEMV_INPUTS(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy) \ - AOCL_DTL_log_gemv_sizes(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_gemv_sizes(loglevel, dt_type, transa, m, n, alp, lda, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_GER_INPUTS(loglevel, dt_type, m, n, alpha, incx, incy, lda) \ - AOCL_DTL_log_ger_sizes(loglevel, dt_type, m, n, alpha, incx, incy, lda, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_ger_sizes(loglevel, dt_type, m, n, alpha, incx, incy, lda, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, lda )\ - AOCL_DTL_log_her_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy)\ - AOCL_DTL_log_symv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_symv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); // Level-1 Macros #define AOCL_DTL_LOG_COPY_INPUTS(loglevel, dt_type, n, incx, incy) \ - AOCL_DTL_log_copy_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_copy_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_SCAL_INPUTS(loglevel, dt_type, alpha, n, incx )\ - AOCL_DTL_log_scal_sizes(loglevel, dt_type, alpha, n, incx, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_scal_sizes(loglevel, dt_type, alpha, n, incx, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SWAP_INPUTS(loglevel, dt_type, n, incx, incy)\ - AOCL_DTL_log_swap_sizes(loglevel, dt_type, n, incx, incy, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_swap_sizes(loglevel, dt_type, n, incx, incy, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_NRM2_INPUTS(loglevel, dt_type, n, incx)\ - AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_nrm2_sizes(loglevel, dt_type, n, incx, __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_HEMV_INPUTS(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy) \ - AOCL_DTL_log_hemv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_hemv_sizes(loglevel, dt_type, uploa, m, alpha, lda, incx, beta, incy, \ + __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_HER2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ - AOCL_DTL_log_her2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, \ - __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_her2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, \ + __FILE__, __FUNCTION__, __LINE__); // Level-1 Macros #define AOCL_DTL_LOG_AMAX_INPUTS(loglevel, dt_type, n, incx) \ - AOCL_DTL_log_amax_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_amax_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_ASUM_INPUTS(loglevel, dt_type, n, incx) \ - AOCL_DTL_log_asum_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_asum_sizes(loglevel, dt_type, n, incx, __FILE__, __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_AXPBY_INPUTS(loglevel, dt_type, n, alpha, incx, beta, incy) \ - AOCL_DTL_log_axpby_sizes(loglevel, dt_type, n, alpha, incx, beta, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_axpby_sizes(loglevel, dt_type, n, alpha, incx, beta, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_AXPY_INPUTS(loglevel, dt_type, n, alpha, incx, incy) \ - AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ - __FUNCTION__, __LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_axpy_sizes(loglevel, dt_type, n, alpha, incx, incy, __FILE__,\ + __FUNCTION__, __LINE__); #define AOCL_DTL_LOG_DOTV_INPUTS(loglevel, dt_type, n, incx, incy) \ - AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_dotv_sizes(loglevel, dt_type, n, incx, incy, __FILE__, __FUNCTION__, __LINE__); \ #define AOCL_DTL_LOG_SYR2_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, incy, lda) \ - AOCL_DTL_log_syr2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr2_sizes(loglevel, dt_type, uploa, m, alpha, incx, incy, lda, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYR2K_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta, ldc) \ - AOCL_DTL_log_syr2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta,\ - ldc, __FILE__, __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr2k_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, ldb, beta,\ + ldc, __FILE__, __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYR_INPUTS(loglevel, dt_type, uploa, m, alpha, incx, lda) \ - AOCL_DTL_log_syr_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syr_sizes(loglevel, dt_type, uploa, m, alpha, incx, lda,\ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_SYRK_INPUTS(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc) \ - AOCL_DTL_log_syrk_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_syrk_sizes(loglevel, dt_type, uploc, transa, m, k, alpha, lda, beta, ldc, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRMM_INPUTS(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb) \ - AOCL_DTL_log_trmm_sizes(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__,\ - __FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trmm_sizes(loglevel, dt_type, side, uploa, transa, diaga, m, n, alpha, lda, ldb, __FILE__,\ + __FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRMV_INPUTS(loglevel, dt_type, uploa, transa, diaga, m, lda, incx) \ - AOCL_DTL_log_trmv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trmv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ + __FILE__,__FUNCTION__,__LINE__); #define AOCL_DTL_LOG_TRSV_INPUTS(loglevel, dt_type, uploa, transa, diaga, m, lda, incx ) \ - AOCL_DTL_log_trsv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ - __FILE__,__FUNCTION__,__LINE__); + if (gbIsLoggingEnabled) \ + AOCL_DTL_log_trsv_sizes(loglevel, dt_type, uploa, transa, diaga, m, lda, incx,\ + __FILE__,__FUNCTION__,__LINE__); #else #define AOCL_DTL_LOG_GEMM_INPUTS(loglevel, dt, transa, transb, m, n, k, alpha, lda, ldb, beta, ldc) diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 4f1e923a05..9420e7d364 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -5,7 +5,7 @@ * libaray, all debug features (except auto trace) * can be enabled/disabled in this file. * - * Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. * *==================================================================*/ @@ -20,9 +20,21 @@ enable this macro by making it to 1 else 0 */ #define AOCL_DTL_DUMP_ENABLE 0 -/* Macro for logging the logs If the user wants to enable loging information he - has to enable this macro by making it to 1 else 0 */ -#define AOCL_DTL_LOG_ENABLE 0 +/* + * Logging of inputs can be enabled by two methods: + * + * 1. Using environment variable AOCL_VERBOSE. + * 2. APIs AOCL_DTL_Enable_Logs(), AOCL_DTL_Disable_Logs() + * + * The API takes precedence over environment variable. + * + * The global flag is maintain in the code to track the final + * state of the logging feature. + * + * Setting AOCL_DTL_LOG_ENABLE = 0 will disable this feature + * completely and it is not recommended. + */ +#define AOCL_DTL_LOG_ENABLE 1 /* Select the trace level till which you want to log the data */ /* By default it will log for all levels */ From 46f59e7139bdca734c941b44e560f05cac315e25 Mon Sep 17 00:00:00 2001 From: Dave Date: Fri, 22 Apr 2022 11:47:00 +0530 Subject: [PATCH 47/63] Enabled zgemm_sup path and removed sqp path - Previously zgemm computation failures were due to status variable did not have pre-defined initial value which resulted in zgemm computation to return without being computed by any kernel. Reflected same change in dgemm_ function as well. - Enabled sup zgemm as the issue is fixed with status variable with bli_zgemm_small call. -Removed calling sqp method as it is disabled Change-Id: I0f4edfd619bc4877ebfc5cb6532c26c3888f919d --- frame/compat/bla_gemm_amd.c | 48 +++++-------------------------------- 1 file changed, 6 insertions(+), 42 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 7060509de2..681869c9b8 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -560,7 +560,7 @@ void dgemm_ if( ( ( (m0 + n0 -k0) < 2000) && ((m0 + k0-n0) < 2000) && ((n0 + k0-m0) < 2000) ) || ((n0 <= 10) && (k0 <=10)) ) { - err_t status; + err_t status = BLIS_FAILURE; if (bli_is_notrans(blis_transa)) { status = bli_dgemm_small( &alphao, @@ -754,50 +754,14 @@ void zgemm_ } #endif - // The code below will be called when number of threads = 1. -#if 0//ENABLE_INDUCED_METHOD - /* 3m_sqp is optimal for certain matrix shapes. - Initial study that it works well for square sizes and sizes closer to square shape. - - * Usage of 3m_sqp is restricted to sizes, where it is found efficient compared to native, sup and other induced method. - * Further investigation is necessary to make the usage choices more generic. */ - bool sqp_on = false; - if( (m0 == n0 ) && ( n0 == k0 ) && ( m0 == 128 ) ) - { - sqp_on = true; - } - - // current range of sizes used for 3m_sqp to be expaned after evaluation. - if( ( m0 >= 4200) && ( m0 <= 4600 ) && ( ( n0 >= 326 ) || (n0 <= 1600 ) ) - && ( k0 == 1120 ) ) //to be tuned further. - { - sqp_on = true; - } - - if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) ) + err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); + if(status==BLIS_SUCCESS) { - //sqp algo is found better for n > 40 - if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } + AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) + return; } -#endif//ENABLE_INDUCED_METHOD - -// sup has been disabled. - if(0) - { - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - } // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); From 11f04ba34e231b981d08ebca9661d045bb5c7b26 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 25 Apr 2022 15:58:10 +0530 Subject: [PATCH 48/63] Updated version and copyright notice. Changed AMD-BLIS version to 3.1.2 AMD-Internal: [CPUPL-2111] Change-Id: Id8fc3fbc112f08bd5e5def646c472047352e65b5 --- LICENSE | 2 +- so_version | 2 +- version | 3 +-- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/LICENSE b/LICENSE index 0e7a6071d2..be24a09734 100644 --- a/LICENSE +++ b/LICENSE @@ -15,7 +15,7 @@ copyright info. All parties provide their portions of the code under the Copyright (C) 2018, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP -Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. +Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are diff --git a/so_version b/so_version index 8efd5969fe..77605e74c7 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -2.0 +1.2 diff --git a/version b/version index 252fb77212..ef538c2810 100644 --- a/version +++ b/version @@ -1,2 +1 @@ -3.2.0 - +3.1.2 From a95349c0c179fec3d60a370932f729b75a1ab1b1 Mon Sep 17 00:00:00 2001 From: "S, HariharaSudhan" Date: Tue, 29 Mar 2022 18:05:59 +0530 Subject: [PATCH 49/63] Multithreaded SGEMV var 1 with smart threading - Implemented an OpenMP based stand alone SGEMV kernel for row-major (var 1) for multithread scenarios - Smart threading is enabled when AOCL DYNAMIC is defined - Number of threads are decided based on the input dims using smart threading AMD-Internal: [CPUPL-1984] Change-Id: I9b191e965ba7468e95aabcce21b35a533017502e --- frame/2/gemv/bli_gemv_unf_var1_amd.c | 128 ++++++++- kernels/zen/2/bli_gemv_zen_int_4.c | 395 +++++++++++++++++++++++++++ 2 files changed, 522 insertions(+), 1 deletion(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 7228c12f75..8295f3927e 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -332,6 +332,92 @@ void bli_dgemv_unf_var1 AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_3); } +// Returns the optimal number of threads for the given input sizes and fuse factor +void bli_sgemv_var1_smart_threading + ( + dim_t m, dim_t n, + dim_t fuse, + dim_t* nt, dim_t nt_max + ) +{ + // Calculate the amount data processed per iteration + dim_t n_per_loop = n / fuse; + double data_per_iter = n_per_loop* m; + double m_n_ratio = m/n; + + // When the input value is less than the fuse factor + if(n_per_loop < 1) + { + *nt = 1; + return; + } + + // Then there are two cases one + // In m < n the thread spawning is less aggressive when compared to m > n and m = n cases + if(m_n_ratio <= 0.6) + { + // Boundary units is the amount of data processed by each iteration + // This is the variable X in the equation + const double lower_boundary = 50000; + const double higher_boundary = 500000; + + if(data_per_iter < lower_boundary) + { + double coeff_x = 0.9148; + double constant = -1.6252; + // Number of threads = 0.9148 * log(x) - 1.6252 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else if(data_per_iter < higher_boundary) + { + float coeff_x = 10.23; + float constant = -82.332; + // Number of threads = 10.23 * log(x) - 82.332 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else + { + // When the amount of data to be processed is above both of the boundaries + // The number of threads spawned will be equal to the max number of threads set + *nt = nt_max; + } + } + else + { + // Boundary units is the amount of data processed by each iteration + // This is the variable X in the equation + const float lower_boundary = 50000; + const float higher_boundary = 360000; + + if(data_per_iter < lower_boundary) + { + float coeff_x2 = -2E-09; + float coeff_x = 0.0002; + float constant = 1.0234; + // Number of threads = -2E-09*x^2 + 0.0002 * x + 1.0234 + *nt = ceil(coeff_x2 * (data_per_iter * data_per_iter) + coeff_x * data_per_iter + constant); + } + else if(data_per_iter < higher_boundary) + { + float coeff_x = 16.917; + float constant = -164.82; + // Number of threads = 16.917 * log(x) - 164.82 + *nt = ceil(coeff_x * log(data_per_iter) + constant); + } + else + { + // When the amount of data to be processed is above both of the boundaries + // The number of threads spawned will be equal to the max number of threads set + *nt = nt_max; + } + } + + // When the number of threads calculated is greater than the user provided value + // Choose the user provided value + if(*nt > nt_max) + *nt = nt_max; +} + void bli_sgemv_unf_var1 ( trans_t transa, @@ -407,7 +493,46 @@ void bli_sgemv_unf_var1 return; } - /* Query the context for the kernel function pointer and fusing factor. */ +// If both multithreading and OpenMP are enabled, GEMV will multithread +#if defined(BLIS_ENABLE_MULTITHREADING) && defined(BLIS_ENABLE_OPENMP) + dim_t nt, nt_max; + + rntm_t rnmt_obj; + + b_fuse = 4; + + // Initialize a local runtime with global settings. + bli_rntm_init_from_global( &rnmt_obj ); + + // Query the total number of threads from the rntm_t object. + nt_max = bli_rntm_num_threads( &rnmt_obj ); + + //Setting the thread count to the maximum number of threads provided + nt = nt_max; + + // Enable smart threading when AOCL dynamic is enabled + #ifdef AOCL_DYNAMIC + bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); + #endif + + // Pass the input paramaters along with the number of threads to be used + bli_multi_sgemv_4x2 + ( + conja, + conjx, + n_elem, + n_iter, + alpha, + a, cs_at, rs_at, + x, incx, + beta, + y, incy, + cntx, + nt + ); + +#else + b_fuse = 8; for ( i = 0; i < n_iter; i += f ) @@ -434,6 +559,7 @@ void bli_sgemv_unf_var1 ); } +#endif } INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) diff --git a/kernels/zen/2/bli_gemv_zen_int_4.c b/kernels/zen/2/bli_gemv_zen_int_4.c index b3c92b551c..74904605ee 100644 --- a/kernels/zen/2/bli_gemv_zen_int_4.c +++ b/kernels/zen/2/bli_gemv_zen_int_4.c @@ -35,6 +35,24 @@ #include "immintrin.h" #include "blis.h" +/* Union data structure to access AVX registers + One 256-bit AVX register holds 8 SP elements. */ +typedef union +{ + __m256 v; + float f[8] __attribute__((aligned(64))); +} v8sf_t; + + +/* Union data structure to access AVX registers +* One 128-bit AVX register holds 4 SP elements. */ +typedef union +{ + __m128 v; + float f[4] __attribute__((aligned(64))); +} v4sf_t; + + /* This implementation uses 512 bits of cache line efficiently for column stored matrix and vectors. @@ -477,3 +495,380 @@ void bli_cgemv_zen_int_4x4 } } + +/* +Function performs multithreaded GEMV for float datatype +All parameters are similar to single thread GEMV except +n_thread which specifies the number of threads to be used +*/ +void bli_multi_sgemv_4x2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx, + dim_t n_threads + ) +{ + const dim_t b_fuse = 4; + const dim_t n_elem_per_reg = 8; + dim_t total_iteration = 0; + + // If the b_n dimension is zero, y is empty and there is no computation. + if (bli_zero_dim1(b_n)) + return; + + // If the m dimension is zero, or if alpha is zero, the computation + // simplifies to updating y. + if (bli_zero_dim1(m) || PASTEMAC(s, eq0)(*alpha)) + { + + bli_sscalv_zen_int10( + BLIS_NO_CONJUGATE, + b_n, + beta, + y, incy, + cntx); + return; + } + + // If b_n is not equal to the fusing factor, then perform the entire + // operation as a loop over dotxv. + if (b_n < b_fuse) + { + for (dim_t i = 0; i < b_n; ++i) + { + float *a1 = a + (0) * inca + (i)*lda; + float *x1 = x + (0) * incx; + float *psi1 = y + (i)*incy; + + bli_sdotxv_zen_int( + conjat, + conjx, + m, + alpha, + a1, inca, + x1, incx, + beta, + psi1, + cntx); + } + return; + } + + // Calculate the total number of multithreaded iteration + total_iteration = b_n / b_fuse; + +#pragma omp parallel for num_threads(n_threads) + for (dim_t j = 0; j < total_iteration; j++) + { + float *A1 = a + (b_fuse * j) * lda; + float *x1 = x; + float *y1 = y + (b_fuse * j) * incy; + + // Intermediate variables to hold the completed dot products + float rho0[4] = {0, 0, 0, 0}; + + // If vectorization is possible, perform them with vector + // instructions. + if (inca == 1 && incx == 1) + { + const dim_t n_iter_unroll = 2; + + // Use the unrolling factor and the number of elements per register + // to compute the number of vectorized and leftover iterations. + dim_t l, unroll_inc, m_viter[2], m_left = m; + + unroll_inc = n_elem_per_reg * n_iter_unroll; + + m_viter[0] = m_left / unroll_inc; + m_left = m_left % unroll_inc; + + m_viter[1] = m_left / n_elem_per_reg ; + m_left = m_left % n_elem_per_reg; + + // Set up pointers for x and the b_n columns of A (rows of A^T). + float *restrict x0 = x1; + float *restrict av[4]; + + av[0] = A1 + 0 * lda; + av[1] = A1 + 1 * lda; + av[2] = A1 + 2 * lda; + av[3] = A1 + 3 * lda; + + // Initialize b_n rho vector accumulators to zero. + v8sf_t rhov[4]; + + rhov[0].v = _mm256_setzero_ps(); + rhov[1].v = _mm256_setzero_ps(); + rhov[2].v = _mm256_setzero_ps(); + rhov[3].v = _mm256_setzero_ps(); + + v8sf_t xv[2]; + v8sf_t a_vec[8]; + + // FMA operation is broken down to mul and add + // to reduce backend stalls + for (l = 0; l < m_viter[0]; ++l) + { + xv[0].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + xv[1].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + + a_vec[0].v = _mm256_loadu_ps(av[0]); + a_vec[4].v = _mm256_loadu_ps(av[0] + n_elem_per_reg); + + // perform: rho?v += a?v * x0v; + a_vec[0].v = _mm256_mul_ps(a_vec[0].v, xv[0].v); + rhov[0].v = _mm256_fmadd_ps(a_vec[4].v, xv[1].v, rhov[0].v); + rhov[0].v = _mm256_add_ps(a_vec[0].v, rhov[0].v); + + a_vec[1].v = _mm256_loadu_ps(av[1]); + a_vec[5].v = _mm256_loadu_ps(av[1] + n_elem_per_reg); + + a_vec[1].v = _mm256_mul_ps(a_vec[1].v, xv[0].v); + rhov[1].v = _mm256_fmadd_ps(a_vec[5].v, xv[1].v, rhov[1].v); + rhov[1].v = _mm256_add_ps(a_vec[1].v, rhov[1].v); + + a_vec[2].v = _mm256_loadu_ps(av[2]); + a_vec[6].v = _mm256_loadu_ps(av[2] + n_elem_per_reg); + + a_vec[2].v = _mm256_mul_ps(a_vec[2].v, xv[0].v); + rhov[2].v = _mm256_fmadd_ps(a_vec[6].v, xv[1].v, rhov[2].v); + rhov[2].v = _mm256_add_ps(a_vec[2].v, rhov[2].v); + + a_vec[3].v = _mm256_loadu_ps(av[3]); + a_vec[7].v = _mm256_loadu_ps(av[3] + n_elem_per_reg); + + a_vec[3].v = _mm256_mul_ps(a_vec[3].v, xv[0].v); + rhov[3].v = _mm256_fmadd_ps(a_vec[7].v, xv[1].v, rhov[3].v); + rhov[3].v = _mm256_add_ps(a_vec[3].v, rhov[3].v); + + av[0] += unroll_inc; + av[1] += unroll_inc; + av[2] += unroll_inc; + av[3] += unroll_inc; + } + + for (l = 0; l < m_viter[1]; ++l) + { + // Load the input values. + xv[0].v = _mm256_loadu_ps(x0); + x0 += n_elem_per_reg; + + a_vec[0].v = _mm256_loadu_ps(av[0]); + a_vec[1].v = _mm256_loadu_ps(av[1]); + + rhov[0].v = _mm256_fmadd_ps(a_vec[0].v, xv[0].v, rhov[0].v); + rhov[1].v = _mm256_fmadd_ps(a_vec[1].v, xv[0].v, rhov[1].v); + + av[0] += n_elem_per_reg; + av[1] += n_elem_per_reg; + + a_vec[2].v = _mm256_loadu_ps(av[2]); + a_vec[3].v = _mm256_loadu_ps(av[3]); + + rhov[2].v = _mm256_fmadd_ps(a_vec[2].v, xv[0].v, rhov[2].v); + rhov[3].v = _mm256_fmadd_ps(a_vec[3].v, xv[0].v, rhov[3].v); + + av[2] += n_elem_per_reg; + av[3] += n_elem_per_reg; + } + + // Sum the elements within each vector. + // Sum the elements of a given rho?v with hadd. + rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[1].v); + rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[3].v); + rhov[0].v = _mm256_hadd_ps(rhov[0].v, rhov[0].v); + rhov[2].v = _mm256_hadd_ps(rhov[2].v, rhov[2].v); + + // Manually add the results from above to finish the sum. + rho0[0] = rhov[0].f[0] + rhov[0].f[4]; + rho0[1] = rhov[0].f[1] + rhov[0].f[5]; + rho0[2] = rhov[2].f[0] + rhov[2].f[4]; + rho0[3] = rhov[2].f[1] + rhov[2].f[5]; + + // If leftover elements are more than 4, perform SSE + if (m_left > 4) + { + v4sf_t xv128, a_vec128[4], rhov128[4]; + + rhov128[0].v = _mm_set1_ps(0); + rhov128[1].v = _mm_set1_ps(0); + rhov128[2].v = _mm_set1_ps(0); + rhov128[3].v = _mm_set1_ps(0); + + // Load the input values. + xv128.v = _mm_loadu_ps(x0 + 0 * n_elem_per_reg); + x0 += 4; + m_left -= 4; + + a_vec128[0].v = _mm_loadu_ps(av[0]); + a_vec128[1].v = _mm_loadu_ps(av[1]); + + // perform: rho?v += a?v * x0v; + rhov128[0].v = _mm_fmadd_ps(a_vec128[0].v, xv128.v, rhov128[0].v); + rhov128[1].v = _mm_fmadd_ps(a_vec128[1].v, xv128.v, rhov128[1].v); + rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[1].v); + rhov128[0].v = _mm_hadd_ps(rhov128[0].v, rhov128[0].v); + + a_vec128[2].v = _mm_loadu_ps(av[2]); + a_vec128[3].v = _mm_loadu_ps(av[3]); + + rhov128[2].v = _mm_fmadd_ps(a_vec128[2].v, xv128.v, rhov128[2].v); + rhov128[3].v = _mm_fmadd_ps(a_vec128[3].v, xv128.v, rhov128[3].v); + rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[3].v); + rhov128[2].v = _mm_hadd_ps(rhov128[2].v, rhov128[2].v); + + rho0[0] += rhov128[0].f[0]; + rho0[1] += rhov128[0].f[1]; + rho0[2] += rhov128[2].f[0]; + rho0[3] += rhov128[2].f[1]; + + av[0] += 4; + av[1] += 4; + av[2] += 4; + av[3] += 4; + } + + // If there are leftover iterations, perform them with scalar code. + for (l = 0; l < m_left; ++l) + { + rho0[0] += *(av[0]) * (*x0); + rho0[1] += *(av[1]) * (*x0); + rho0[2] += *(av[2]) * (*x0); + rho0[3] += *(av[3]) * (*x0); + + x0 += incx; + av[0] += inca; + av[1] += inca; + av[2] += inca; + av[3] += inca; + } + + } + else + { + // When vectorization is not possible, perform with scalar code + + // Initialize pointers for x and the b_n columns of A (rows of A^T). + float *restrict x0 = x1; + float *restrict a0 = A1 + 0 * lda; + float *restrict a1 = A1 + 1 * lda; + float *restrict a2 = A1 + 2 * lda; + float *restrict a3 = A1 + 3 * lda; + + for (dim_t l = 0; l < m; ++l) + { + const float x0c = *x0; + + const float a0c = *a0; + const float a1c = *a1; + const float a2c = *a2; + const float a3c = *a3; + + rho0[0] += a0c * x0c; + rho0[1] += a1c * x0c; + rho0[2] += a2c * x0c; + rho0[3] += a3c * x0c; + + x0 += incx; + a0 += inca; + a1 += inca; + a2 += inca; + a3 += inca; + } + } + + v4sf_t rho0v, y0v; + + rho0v.v = _mm_loadu_ps(rho0); + + // Broadcast the alpha scalar. + v4sf_t alphav; + alphav.v = _mm_broadcast_ss(alpha); + + // We know at this point that alpha is nonzero; however, beta may still + // be zero. If beta is indeed zero, we must overwrite y rather than scale + // by beta (in case y contains NaN or Inf). + if (PASTEMAC(s, eq0)(*beta)) + { + // Apply alpha to the accumulated dot product in rho: + // y := alpha * rho + y0v.v = _mm_mul_ps(alphav.v, rho0v.v); + } + else + { + // Broadcast the beta scalar. + v4sf_t betav; + betav.v = _mm_broadcast_ss(beta); + + if (incy == 0) + { + // Load y. + y0v.v = _mm_loadu_ps(y1 + 0 * n_elem_per_reg); + } + else + { + // Load y. + y0v.f[0] = *(y1 + 0 * incy); + y0v.f[1] = *(y1 + 1 * incy); + y0v.f[2] = *(y1 + 2 * incy); + y0v.f[3] = *(y1 + 3 * incy); + } + + // Apply beta to y and alpha to the accumulated dot product in rho: + // y := beta * y + alpha * rho + y0v.v = _mm_mul_ps(betav.v, y0v.v); + y0v.v = _mm_fmadd_ps(alphav.v, rho0v.v, y0v.v); + } + + // Store the output. + if (incy == 1) + { + _mm_storeu_ps((y1 + 0 * n_elem_per_reg), y0v.v); + } + else + { + // Store the output. + *(y1 + 0 * incy) = y0v.f[0]; + *(y1 + 1 * incy) = y0v.f[1]; + *(y1 + 2 * incy) = y0v.f[2]; + *(y1 + 3 * incy) = y0v.f[3]; + } + } + + // Performs the complete computation if OpenMP is not enabled + dim_t start = total_iteration * b_fuse; + dim_t new_fuse = 8, f; + + // Left over corner cases completed using fused kernel + for (dim_t i = start; i < b_n; i += f) + { + f = bli_determine_blocksize_dim_f(i, b_n, new_fuse); + + float *A1 = a + (i)*lda + (0) * inca; + float *x1 = x + (0) * incx; + float *y1 = y + (i)*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8( + conjat, + conjx, + m, + f, + alpha, + A1, inca, lda, + x1, incx, + beta, + y1, incy, + cntx); + } +} From 2ca6015aa5fe2b42d5b8cf078ee20df4b9a9d090 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 22 Mar 2022 21:17:43 +0530 Subject: [PATCH 50/63] Smart Threading for GEMM (sgemm) v1. - Cache aware factorization. Experiments shows that ic,jc factorization based on m,n gives better results compared to mu,nu on a generic data set in SUP path. Also slight adjustments in the factorizations w.r.t matrix data loads can help in improving perf further. - Moving native path inputs to SUP path. Experiments shows that in multi-threaded scenarios if the per thread data falls under SUP thresholds, taking SUP path instead of native path results in improved performance. This is the case even if the original matrix dimensions falls in native path. This is not applicable if A matrix transpose is required. - Enabling B matrix packing in SUP path. Performance improvement is observed when B matrix is packed in cases where gemm takes SUP path instead of native path based on per thread matrix dimensions. AMD-Internal: [CPUPL-659] Change-Id: I3b8fc238a0ece1ababe5d64aebab63092f7c6914 --- frame/3/CMakeLists.txt | 3 +- frame/3/bli_l3.h | 5 +- frame/3/bli_l3_smart_threading.c | 557 +++++++++++++++++++++++++++++++ frame/3/bli_l3_smart_threading.h | 68 ++++ frame/3/bli_l3_sup.c | 49 +-- frame/3/bli_l3_sup_int_amd.c | 5 +- frame/base/bli_rntm.c | 82 ++++- frame/base/bli_rntm.h | 12 +- 8 files changed, 755 insertions(+), 26 deletions(-) create mode 100644 frame/3/bli_l3_smart_threading.c create mode 100644 frame/3/bli_l3_smart_threading.h diff --git a/frame/3/CMakeLists.txt b/frame/3/CMakeLists.txt index b3aaf2c8c8..e9d7da7b8e 100644 --- a/frame/3/CMakeLists.txt +++ b/frame/3/CMakeLists.txt @@ -25,6 +25,7 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_fpa.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_oapi.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_ukr_tapi.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_smart_threading.c ) # Select AMD specific sources for AMD configurations. if(${TARGET_ARCH} STREQUAL zen OR @@ -38,7 +39,7 @@ if(${TARGET_ARCH} STREQUAL zen OR else() target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int_amd.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_l3_sup_int.c ) endif() diff --git a/frame/3/bli_l3.h b/frame/3/bli_l3.h index b64da054c9..b65edfcaac 100644 --- a/frame/3/bli_l3.h +++ b/frame/3/bli_l3.h @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020, Advanced Micro Devices, Inc. + Copyright (C) 2020-22, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,3 +98,6 @@ #include "bli_trmm3.h" #include "bli_trsm.h" #include "bli_gemmt.h" + +// Smart Threading API's. +#include "bli_l3_smart_threading.h" diff --git a/frame/3/bli_l3_smart_threading.c b/frame/3/bli_l3_smart_threading.c new file mode 100644 index 0000000000..e4b9b43e24 --- /dev/null +++ b/frame/3/bli_l3_smart_threading.c @@ -0,0 +1,557 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" +#include "bli_l3_smart_threading.h" + +#ifdef AOCL_DYNAMIC + +// Utility functions. +static inline dim_t next_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == nt) + { + return part_nt; + } + + dim_t nt_temp = part_nt + 1; + while ( ( nt_temp <= nt ) && ( ( nt % nt_temp ) != 0 ) ) + { + nt_temp++; + } + return nt_temp; +} + +static inline dim_t prev_factor + ( + const dim_t nt, + const dim_t part_nt + ) +{ + if ( part_nt == 1) + { + return part_nt; + } + + dim_t nt_temp = part_nt - 1; + while ((nt_temp >= 1) && ((nt % nt_temp) != 0)) + { + nt_temp--; + } + return nt_temp; +} +// End utility functions. + +static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +static err_t bli_gemm_ic_jc_optimum_sup_zen3 + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +static void bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + const dim_t m, + const dim_t n, + const dim_t k, + dim_t nt, + dim_t* ic, + dim_t* jc, + const dim_t MR, + const dim_t NR, + const dim_t MC, + const dim_t KC + ); + +err_t bli_check_and_transform_native_to_SUP + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + dim_t ic, + dim_t jc, + const dim_t NR, + const dim_t MC, + const dim_t KC, + cntx_t* cntx, + rntm_t* rntm + ); + +err_t bli_gemm_smart_threading_sup + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + + // Sanity check, max available threads should be atleast 4 for the + // smart threading/factorization to be meaningful. For nt < 4 the + // default ic,jc factorization holds good. + if ( ( m <= 1 ) || ( n <= 1 ) || ( k <= 1 ) || ( max_available_nt < 4 ) ) + { + return ret_val; + } + + if ( bli_is_float( dt ) ) + { + ret_val = bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + max_available_nt, cntx, rntm + ); + } + else + { + // Other data types not supported for now. + } + + if ( ret_val == BLIS_SUCCESS ) + { + // This is a workaround to ensure that auto_factor attribute of rntm_t + // is not set to TRUE inside bli_rntm_set_ways_from_rntm_sup. Also + // the nt value will be properly set to ic*jc towards the end of + // bli_rntm_set_ways_from_rntm_sup. + bli_rntm_set_num_threads_only( -1, rntm ); + } + + return ret_val; +} + +static err_t bli_gemm_ic_jc_optimum_sup_arch_dispatcher + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + + arch_t id = bli_arch_query_id(); + if ( id == BLIS_ARCH_ZEN3 ) + { + ret_val = bli_gemm_ic_jc_optimum_sup_zen3 + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + max_available_nt, cntx, rntm + ); + } + else + { + // Other architectures not supported for now. + } + + return ret_val; +} + +// open zen3 region. +#define NUM_CORES_PER_CCD_ZEN3 8 + +// Determines the optimal number of threads (nt) and corresponding work split +// (ic,jc factorization of nt) for gemm on zen3 machines. +static err_t bli_gemm_ic_jc_optimum_sup_zen3 + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_SUCCESS; + + const dim_t MR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MR, cntx ); + const dim_t NR = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NR, cntx ); + const dim_t MC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_MC, cntx ); + const dim_t NC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_NC, cntx ); + const dim_t KC = bli_cntx_get_l3_sup_blksz_def_dt( dt, BLIS_KC, cntx ); + + dim_t ic = -1; + dim_t jc = -1; + + bli_thread_partition_2x2( max_available_nt, m, n, &ic, &jc ); + + dim_t jc_per_ccd = ( NUM_CORES_PER_CCD_ZEN3 + ic - 1 ) / ic ; + dim_t b_mat_data_per_ccd = jc_per_ccd * ( n / jc ); + + // All the cores (8) on a CCD share a L3 cache and hence total data + // loaded by the cores on a CCD should be < NC to avoid L3 contention. + // In cases where it is violated, it is better to increase ic and + // reduce B data per CCD, using micro panels mu, nu for thread + // partitioning can help achieve this. Avoiding further ic,jc + // adjustment in this case. + if ( b_mat_data_per_ccd > NC ) + { + const dim_t mu = m / MR; + const dim_t nu = n / NR; + bli_thread_partition_2x2( max_available_nt, mu, nu, &ic, &jc ); + } + else + { + // Adjust the ic,jc in the best match so that m_ic and n_jc + // turns out to be more cache friendly. + bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + m, n, k, max_available_nt, &ic, &jc, MR, NR, MC, KC + ); + } + + ret_val = bli_check_and_transform_native_to_SUP + ( + dt, elem_size, is_rrr_rrc_rcr_crr, m, n, k, + ic, jc, NR, MC, KC, cntx, rntm + ); + + if ( ret_val == BLIS_SUCCESS ) + { + bli_rntm_set_ic_ways_only( ic, rntm ); + bli_rntm_set_jc_ways_only( jc, rntm ); + } + + return ret_val; +} + +// The factorization of nt into ic,jc is based on m and n values (for simplicity +// it can be assumed to be based on m:n ratio). It does not take into account +// how the matrices are loaded into cache or which matrix goes to the larger +// cache. Depending on the matrix dimensions, increasing the ic can result in +// reduced loads from main memory to L2 cache for A matrix without any impact on +// B matrix load (since B is streamed into L3, which is larger). Similary +// adjusting jc can result in B matrix panels fitting perfectly within the L1 +// cache.This function makes these adjustments on ic,jc. +static void bli_gemm_cache_heur_adjust_ic_jc_sup_zen3 + ( + const dim_t m, + const dim_t n, + const dim_t k, + dim_t nt, + dim_t* ic, + dim_t* jc, + const dim_t MR, + const dim_t NR, + const dim_t MC, + const dim_t KC + ) +{ + const dim_t m_ic = m / ( *ic ); + const dim_t n_jc = n / ( *jc ); + const int64_t cur_work_per_thread = m_ic + n_jc; + + // The next and prev factors are caluclated with respect to the current + // factor part of nt. In effect + // 1. next ic * prev jc = nt + // 2. prev ic * next jc = nt + // 3. ic * jc = nt + const dim_t next_ic = next_factor( nt, ( *ic ) ); + const dim_t prev_ic = prev_factor( nt, ( *ic ) ); + const dim_t next_jc = next_factor( nt, ( *jc ) ); + const dim_t prev_jc = prev_factor( nt, ( *jc ) ); + + const dim_t m_next_ic = m / next_ic; + const dim_t m_prev_ic = m / prev_ic; + const dim_t n_next_jc = n / next_jc; + const dim_t n_prev_jc = n / prev_jc; + const dim_t n_jc_modulo_NR = n_jc % NR; + const dim_t n_prev_jc_modulo_NR = n_prev_jc % NR; + + const int64_t next_jc_work_per_thread = n_next_jc + m_prev_ic; + const int64_t next_ic_work_per_thread = m_next_ic + n_prev_jc; + + const dim_t MCx2 = MC * 2; + const dim_t NRx4 = NR * 4; + const dim_t NRx8 = NR * 8; + + // MC will be reduced if the following mods are zero. Incrementing jc + // helps in this case. + const dim_t n_mod_256 = n % 256; + const dim_t k_mod_256 = k % 256; + + const dim_t k_factor = k / KC; + + bool can_increase_jc = FALSE; + bool can_increase_ic = FALSE; + + // jc adjustment towards next highest factor if it results in n_jc*KC + // fittting completely within l1d cache. Only done if ic prev factor + // does not move m_prev_ic out of good l2 load zone (MC). + // Performance improvement also observed when n_jc is a multiple of NR. + if ( ( ( *ic ) > 1 ) && ( ( *jc ) < nt ) ) + { + // Check whether m_prev_ic remains in good l2 load zone. + if ( ( ( ( m_ic <= MC ) && ( m_prev_ic <= MC ) ) || + ( m_ic > MC ) ) && + ( ( n_jc > NR ) && ( n_next_jc == NR ) ) ) + { + can_increase_jc = TRUE; + } + // 2x2 factorization doesnt always give equal sum partition. + else if ( next_jc_work_per_thread < cur_work_per_thread ) + { + can_increase_jc = TRUE; + } + } + + // Favor jc if both n and k are multiples of 256 ( high cache line + // replacement ). + if ( ( ( *ic ) < nt ) && ( ( *jc ) > 1) ) + { + // ic adjustment towards next highest factor if it results in + // m_next_ic <= MC. This helps in reducing number of A matrix + // loads per thread to l2 from main memory. + if ( ( m_ic > MC ) && ( m_next_ic <= MC ) && + ( m_next_ic >= MR ) && ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // ic adjustment towards next highest factor resulted in better + // performance when m is sufficiently larger than n and jc prev + // factor did not result in n_prev_jc moving out of good l2 + // load zone (n_jc < 64). + else if ( ( m > ( 5 * n ) ) && ( m_ic >= MCx2 ) && ( k_factor > 4 ) && + ( ( n_jc > NRx4 ) || + ( ( n_jc <= NRx4 ) && ( n_prev_jc <= NRx4 ) ) ) ) + { + can_increase_ic = TRUE; + } + // Performance improvement also observed when n_jc is a multiple + // of NR. + else if ( ( n_jc_modulo_NR != 0 ) && ( n_prev_jc_modulo_NR == 0 ) && + ( k_factor > 4 ) ) + { + can_increase_ic = TRUE; + } + // 2x2 factorization doesnt always give equal sum partition. + else if ( next_ic_work_per_thread <= cur_work_per_thread ) + { + can_increase_ic = TRUE; + } + } + + // Favor jc if both n and k are multiples of 256 ( high cache line + // replacement ). + if ( ( n_mod_256 == 0 ) && ( k_mod_256 == 0 ) && ( k > KC ) ) + { + if ( can_increase_ic == TRUE ) + { + can_increase_ic = FALSE; + } + else if ( can_increase_jc == FALSE ) + { + can_increase_jc = TRUE; + } + } + // If only one of either n or k is a multiple of 256, favour jc if n per + // thread is within a heuristic factor of NR. + else if ( ( ( n_mod_256 == 0 ) || ( k_mod_256 == 0 ) ) && ( k > KC ) ) + { + if ( ( can_increase_ic == TRUE ) && ( n_jc <= NRx8 ) ) + { + can_increase_ic = FALSE; + } + else if ( ( can_increase_jc == FALSE ) && ( n_next_jc <= NRx8 ) ) + { + can_increase_jc = TRUE; + } + } + + // Increasing ic factor is given a higher priority compared to jc + // since it was observed that the A matrix loads (main memory -> l2) had + // more impact on perf compared to B matrix (main memory -> l3 -> l1) + // for the sizes considered. + if ( can_increase_ic ) + { + // It is expected that the larger dimension (m or n) will be + // allocated a larger share of the thread factorization. + if ( ( ( m >= n ) && ( next_ic >= prev_jc ) ) || + ( ( m <= n ) && ( next_ic <= prev_jc ) ) ) + { + *ic = next_ic; + *jc = prev_jc; + } + } + else if ( can_increase_jc ) + { + // It is expected that the larger dimension (m or n) will be + // allocated a larger share of the thread factorization. + if ( ( ( m >= n ) && ( prev_ic >= next_jc ) ) || + ( ( m <= n ) && ( prev_ic <= next_jc ) ) ) + { + *ic = prev_ic; + *jc = next_jc; + } + } +} + +// It was observed that the SUP thresholds can be lowered and applied on a +// per thread basis in multi threaded scenarios. +err_t bli_check_and_transform_native_to_SUP + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + dim_t ic, + dim_t jc, + const dim_t NR, + const dim_t MC, + const dim_t KC, + cntx_t* cntx, + rntm_t* rntm + ) +{ + err_t ret_val = BLIS_FAILURE; + dim_t m_ic; + dim_t n_jc; + + const dim_t MT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); + const dim_t NT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); + const dim_t KT = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); + + const dim_t MT_2 = MT / 2; + const dim_t NTx4 = NT * 4; + const dim_t NRx8 = NR * 8; + + const dim_t page_size = bli_info_get_page_size(); + const dim_t page_size_b_float = page_size / ( dim_t ) elem_size; + const dim_t page_size_b_floatx2 = page_size_b_float * 2; + + // Default SUP check without considering per thread dimensions. + if ( ( k < KT ) || ( m < MT ) || ( n < NT ) ) + { + ret_val = BLIS_SUCCESS; + } + // Per thread SUP limit checking. It was observed that when k is large, + // (twice page size) moving native to SUP did not help even if m_ic or + // n_jc were within SUP limits. + else if ( ( m >= MT ) && ( n >= NT ) && ( k < page_size_b_floatx2 ) ) + { + m_ic = m / ic; + n_jc = n / jc; + // In multi-threaded scenario, it was observed that if the per + // thread m dimension(A matrix) and n dimension(B matrix) is + // within a factor of SUP limits, SUP path without packing + // resulted in gains. Along similar lines, if the B matrix is + // large enough and reuse is good, packing B matrix alone in SUP + // resulted in perf gains. + if ( ( m_ic <= MT_2 ) && ( n_jc < NTx4 ) ) + { + if ( ( k > KC ) && + ( m_ic >= MC ) && ( n_jc >= NT ) ) + { + if ( is_rrr_rrc_rcr_crr ) + { + bli_rntm_set_pack_b( 1, rntm ); + } + else + { + bli_rntm_set_pack_a( 1, rntm ); + } + } + ret_val = BLIS_SUCCESS; + } + else if ( ( n_jc < NT ) && ( m_ic <= MT ) ) + { + if ( ( k > KC ) && ( m_ic >= MC ) && ( n_jc >= NRx8 ) ) + { + if ( is_rrr_rrc_rcr_crr ) + { + bli_rntm_set_pack_b( 1, rntm ); + } + else + { + bli_rntm_set_pack_a( 1, rntm ); + } + } + ret_val = BLIS_SUCCESS; + } + else + { + ret_val = BLIS_FAILURE; + } + } + else + { + ret_val = BLIS_FAILURE; + } + + return ret_val; +} +// close zen3 region. + +#endif diff --git a/frame/3/bli_l3_smart_threading.h b/frame/3/bli_l3_smart_threading.h new file mode 100644 index 0000000000..48a0a17bb2 --- /dev/null +++ b/frame/3/bli_l3_smart_threading.h @@ -0,0 +1,68 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#ifdef AOCL_DYNAMIC + +#ifndef BLIS_L3_SMART_THREADING_H +#define BLIS_L3_SMART_THREADING_H + +// Smart threading encompasses the following multi-threading related +// optimizations: +// 1. Selection of optimal number of threads (BLIS_NUM_THREADS) based +// on matrix dimensions. +// 2. Factorization of threads along m and n dimensions (BLIS_IC_NT, +// BLIS_JC_NT) based on matrix dimensions and cache friendliness. +// 3. Transformation of native to SUP path based on the per thread matrix +// dimensions after thread factorization, given that per thread dimensions +// are within SUP limits. +// 4. Enabling packing of B alone in SUP path if native -> SUP path +// transformation happened and depending on per thread matrix dimensions. +// This function captures smart threading logic fine tuned for gemm SUP path. +// Optimal thread selection is not enabled now. +err_t bli_gemm_smart_threading_sup + ( + num_t dt, + siz_t elem_size, + const bool is_rrr_rrc_rcr_crr, + const dim_t m, + const dim_t n, + const dim_t k, + const dim_t max_available_nt, + cntx_t* cntx, + rntm_t* rntm + ); + +#endif //BLIS_L3_SMART_THREADING_H + +#endif diff --git a/frame/3/bli_l3_sup.c b/frame/3/bli_l3_sup.c index a7d7a7874a..d23df8c1e5 100644 --- a/frame/3/bli_l3_sup.c +++ b/frame/3/bli_l3_sup.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -101,6 +101,34 @@ err_t bli_gemmsup // that function assumes the context pointer is valid. if ( cntx == NULL ) cntx = bli_gks_query_cntx(); + // Initialize a local runtime with global settings if necessary. Note + // that in the case that a runtime is passed in, we make a local copy. + rntm_t rntm_l; + if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } + else { rntm_l = *rntm; rntm = &rntm_l; } + +#ifdef AOCL_DYNAMIC + // Calculating optimal nt and corresponding factorization (ic,jc) here, so + // as to determine the matrix dimensions (A - m, B - n) per thread. This + // can be used to check if dimensions per thread falls under the SUP + // threshold and potentially move some of the native path gemm to SUP path + // in multi-threaded scenario. + err_t smart_threading = bli_smart_threading_sup( a, b, c, BLIS_GEMM, rntm, cntx ); + + if ( smart_threading != BLIS_SUCCESS ) + { + thresh_func_ft func_fp; + func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx); + + // Return early if the sizes are beyond SUP thresholds + if ( !func_fp( a, b, c, cntx ) ) + { + AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, + "SUP - Sizes are beyond SUP thresholds."); + return BLIS_FAILURE; + } + } +#else thresh_func_ft func_fp; func_fp = bli_cntx_get_l3_thresh_func(BLIS_GEMM, cntx); @@ -110,26 +138,7 @@ err_t bli_gemmsup AOCL_DTL_TRACE_EXIT_ERR(AOCL_DTL_LEVEL_TRACE_2, "SUP - Sizes are beyond SUP thresholds."); return BLIS_FAILURE; } - - // Initialize a local runtime with global settings if necessary. Note - // that in the case that a runtime is passed in, we make a local copy. - rntm_t rntm_l; - if ( rntm == NULL ) { bli_rntm_init_from_global( &rntm_l ); rntm = &rntm_l; } - else { rntm_l = *rntm; rntm = &rntm_l; } - -#if 0 -const num_t dt = bli_obj_dt( c ); -const dim_t m = bli_obj_length( c ); -const dim_t n = bli_obj_width( c ); -const dim_t k = bli_obj_width_after_trans( a ); -const dim_t tm = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_MT, cntx ); -const dim_t tn = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_NT, cntx ); -const dim_t tk = bli_cntx_get_l3_sup_thresh_dt( dt, BLIS_KT, cntx ); - -printf( "dims: %d %d %d (threshs: %d %d %d)\n", - (int)m, (int)n, (int)k, (int)tm, (int)tn, (int)tk ); #endif - // We've now ruled out the following two possibilities: // - the ukernel prefers the operation as-is, and the sup thresholds are // unsatisfied. diff --git a/frame/3/bli_l3_sup_int_amd.c b/frame/3/bli_l3_sup_int_amd.c index dc2ce24d2d..e00cc54ad0 100644 --- a/frame/3/bli_l3_sup_int_amd.c +++ b/frame/3/bli_l3_sup_int_amd.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019-21, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -98,7 +98,8 @@ err_t bli_gemmsup_int // If the parallel thread factorization was automatic, we update it // with a new factorization based on the matrix dimensions in units - // of micropanels. + // of micropanels. However in case smart threading is enabled, + // auto_factor will be false. if ( auto_factor ) { // In the block-panel algorithm, the m dimension is parallelized diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index c15650e918..5908471cb2 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -681,4 +681,84 @@ void bli_nthreads_optimum( return; } + +// Calculates the optimum number of threads along with the factorization +// (ic, jc) using m, n, k dimensions. This function modifies only the local +// copy of rntm with optimum threads. Since global rntm remains unchanged the +// num_threads set by application is available in global_rntm data structure. +err_t bli_smart_threading_sup + ( + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm, + cntx_t* cntx + ) +{ + // By default smart threading should be disabled. + err_t ret_val = BLIS_FAILURE; + +#ifndef BLIS_ENABLE_MULTITHREADING + return ret_val; +#endif + + dim_t n_threads = bli_rntm_num_threads( rntm ); + + // For non-openmp based threading, n_threads could be -1. + if ( ( n_threads == -1 ) || ( n_threads == 1 ) ) return ret_val; + + dim_t ic_way = bli_rntm_ic_ways( rntm ); + dim_t jc_way = bli_rntm_jc_ways( rntm ); + + // Dont enable smart threading if the user supplied the factorization. + if( ( ic_way > 0 ) || ( jc_way > 0 ) ) return ret_val; + + // Only supporting sgemm for now. + if ( ( family == BLIS_GEMM ) && bli_obj_is_float( c ) ) + { + dim_t k = bli_obj_width_after_trans(a); + dim_t m = 0; + dim_t n = 0; + + bool trans_A_for_kernel = FALSE; + + const stor3_t stor_id = bli_obj_stor3_from_strides( c, a, b ); + const bool is_rrr_rrc_rcr_crr = ( + stor_id == BLIS_RRR || + stor_id == BLIS_RRC || + stor_id == BLIS_RCR || + stor_id == BLIS_CRR + ); + + // The A and B matrices are swapped based on the storage type in + // var1n2m. Need to account for this when determining ic and jc + // based on m and n dimensions of A and B. + if ( is_rrr_rrc_rcr_crr ) + { + m = bli_obj_length( c ); + n = bli_obj_width( c ); + trans_A_for_kernel = bli_obj_has_trans( a ); + } + else + { + m = bli_obj_width( c ); + n = bli_obj_length( c ); + trans_A_for_kernel = bli_obj_has_trans( b ); + } + + // Take default path if transpose is enabled for A matrix. + if ( trans_A_for_kernel == FALSE ) + { + // A successfull call to smart threading api implies smart + // factorization and possibly native -> SUP path conversion. + // Optimal thread selection is not supported yet. + ret_val = bli_gemm_smart_threading_sup( bli_obj_dt( c ), + bli_obj_elem_size( c ), + is_rrr_rrc_rcr_crr, m, n, k, n_threads, + cntx, rntm ); + } + } + return ret_val; +} #endif // AOCL_DYNAMIC diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h index 5e8e236af6..e28463c5ab 100644 --- a/frame/base/bli_rntm.h +++ b/frame/base/bli_rntm.h @@ -6,7 +6,7 @@ Copyright (C) 2014, The University of Texas at Austin Copyright (C) 2016, Hewlett Packard Enterprise Development LP - Copyright (C) 2018 - 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2018 - 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -400,6 +400,16 @@ void bli_nthreads_optimum opid_t family, rntm_t* rntm ); + +err_t bli_smart_threading_sup + ( + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm, + cntx_t* cntx + ); #endif #endif From 8d7fe0bf314e630cf0072eb96ded7ed6124fc50e Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Thu, 28 Apr 2022 15:52:06 +0530 Subject: [PATCH 51/63] Tuned aocl dynamic for specific range in dgemm Description: 1. Decision logic to choose optimal number of threads for given input dgemm dimensions under aocl dynamic feature were retuned based on latest code. 2. Updated code in few file to avoid compilation warnings. 3. Added a min check for nt in bli_sgemv_var1_smart_threading function AMD-Internal: [ CPUPL-2100 ] Change-Id: I2bc70cc87c73505dd5d2bdafb06193f664760e02 --- bench/bench_ger.c | 9 +- frame/2/gemv/bli_gemv_unf_var1_amd.c | 7 +- frame/base/bli_rntm.c | 324 ++++++++++++++------------- kernels/zen/1f/bli_axpyf_zen_int_5.c | 1 - kernels/zen/bli_kernels_zen.h | 17 ++ 5 files changed, 195 insertions(+), 163 deletions(-) diff --git a/bench/bench_ger.c b/bench/bench_ger.c index f6e5b27f59..fb50c94265 100644 --- a/bench/bench_ger.c +++ b/bench/bench_ger.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2021-22, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -66,7 +66,6 @@ int main( int argc, char** argv ) dim_t p_inc = 0; // to keep track of number of inputs num_t dt; char dt_ch; - char stor_scheme; int r, n_repeats; double dtime; @@ -76,6 +75,10 @@ int main( int argc, char** argv ) FILE* fin = NULL; FILE* fout = NULL; +#ifdef CBLAS + char stor_scheme; +#endif + n_repeats = N_REPEAT; // This macro will get from Makefile. dt = DT; @@ -108,7 +111,9 @@ int main( int argc, char** argv ) inc_t incy; char tmp[256]; // to store function name, line no present in logs. +#ifdef CBLAS stor_scheme = 'C'; +#endif // {S,D,C,Z} {transa m n alpha incx incy lda} diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index 8295f3927e..fd399c6f84 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -412,10 +412,11 @@ void bli_sgemv_var1_smart_threading } } + // When the number of threads calculated is greater than the user provided value // Choose the user provided value - if(*nt > nt_max) - *nt = nt_max; + if(*nt > nt_max ) *nt = nt_max; + if(*nt <=0 ) *nt = 1; } void bli_sgemv_unf_var1 @@ -434,7 +435,6 @@ void bli_sgemv_unf_var1 { float* A1; - float* x1; float* y1; dim_t i; dim_t b_fuse, f; @@ -537,6 +537,7 @@ void bli_sgemv_unf_var1 for ( i = 0; i < n_iter; i += f ) { + float* x1; f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); A1 = a + (i )*rs_at + (0 )*cs_at; diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 5908471cb2..f8d48c4a2e 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -53,7 +53,7 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // or the latest value of number of threads, // if set by the Application using omp_set_num_threads(nt) API. #ifdef BLIS_ENABLE_OPENMP - dim_t n_threads = omp_get_max_threads(); + dim_t n_threads = omp_get_max_threads(); #endif // Acquire the mutex protecting global_rntm. @@ -63,7 +63,7 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // before copying into local rntm structure. This updated value will be // used in the subsequent parallel regions. #ifdef BLIS_ENABLE_OPENMP - global_rntm.num_threads = n_threads; + global_rntm.num_threads = n_threads; #endif *rntm = global_rntm; @@ -75,14 +75,14 @@ void bli_rntm_init_from_global( rntm_t* rntm ) // ----------------------------------------------------------------------------- void bli_rntm_set_ways_for_op - ( - opid_t l3_op, - side_t side, - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + opid_t l3_op, + side_t side, + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { // Set the number of ways for each loop, if needed, depending on what // kind of information is already stored in the rntm_t object. @@ -95,7 +95,7 @@ bli_rntm_print( rntm ); // Now modify the number of ways, if necessary, based on the operation. if ( l3_op == BLIS_TRMM || - l3_op == BLIS_TRSM ) + l3_op == BLIS_TRSM ) { dim_t jc = bli_rntm_jc_ways( rntm ); dim_t pc = bli_rntm_pc_ways( rntm ); @@ -169,12 +169,12 @@ bli_rntm_print( rntm ); } void bli_rntm_set_ways_from_rntm - ( - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { dim_t nt = bli_rntm_num_threads( rntm ); @@ -252,7 +252,7 @@ void bli_rntm_set_ways_from_rntm pc = 1; bli_thread_partition_2x2( nt, m*BLIS_THREAD_RATIO_M, - n*BLIS_THREAD_RATIO_N, &ic, &jc ); + n*BLIS_THREAD_RATIO_N, &ic, &jc ); for ( ir = BLIS_THREAD_MAX_IR ; ir > 1 ; ir-- ) { @@ -290,12 +290,12 @@ void bli_rntm_set_ways_from_rntm } void bli_rntm_set_ways_from_rntm_sup - ( - dim_t m, - dim_t n, - dim_t k, - rntm_t* rntm - ) + ( + dim_t m, + dim_t n, + dim_t k, + rntm_t* rntm + ) { dim_t nt = bli_rntm_num_threads( rntm ); @@ -373,9 +373,9 @@ void bli_rntm_set_ways_from_rntm_sup pc = 1; //bli_thread_partition_2x2( nt, m*BLIS_THREAD_SUP_RATIO_M, - // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); + // n*BLIS_THREAD_SUP_RATIO_N, &ic, &jc ); bli_thread_partition_2x2( nt, m, - n, &ic, &jc ); + n, &ic, &jc ); //printf( "bli_rntm_set_ways_from_rntm_sup(): jc = %d ic = %d\n", (int)jc, (int)ic ); #if 0 @@ -420,9 +420,9 @@ void bli_rntm_set_ways_from_rntm_sup } void bli_rntm_print - ( - rntm_t* rntm - ) + ( + rntm_t* rntm + ) { dim_t af = bli_rntm_auto_factor( rntm ); @@ -434,35 +434,35 @@ void bli_rntm_print dim_t jr = bli_rntm_jr_ways( rntm ); dim_t ir = bli_rntm_ir_ways( rntm ); - printf( "rntm contents nt jc pc ic jr ir\n" ); + printf( "rntm contents nt jc pc ic jr ir\n" ); printf( "autofac? %1d | %4d%4d%4d%4d%4d%4d\n", (int)af, - (int)nt, (int)jc, (int)pc, - (int)ic, (int)jr, (int)ir ); + (int)nt, (int)jc, (int)pc, + (int)ic, (int)jr, (int)ir ); } // ----------------------------------------------------------------------------- dim_t bli_rntm_calc_num_threads_in - ( - bszid_t* restrict bszid_cur, - rntm_t* restrict rntm - ) + ( + bszid_t* restrict bszid_cur, + rntm_t* restrict rntm + ) { - /* // bp algorithm: - bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop - BLIS_KC, // level 1: 4th loop + /* // bp algorithm: + bszid_t bszids[7] = { BLIS_NC, // level 0: 5th loop + BLIS_KC, // level 1: 4th loop BLIS_NO_PART, // level 2: pack B - BLIS_MC, // level 3: 3rd loop + BLIS_MC, // level 3: 3rd loop BLIS_NO_PART, // level 4: pack A - BLIS_NR, // level 5: 2nd loop - BLIS_MR, // level 6: 1st loop - BLIS_KR // level 7: ukr loop - - ... // pb algorithm: - BLIS_NR, // level 5: 2nd loop - BLIS_MR, // level 6: 1st loop - BLIS_KR // level 7: ukr loop - }; */ + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + + ... // pb algorithm: + BLIS_NR, // level 5: 2nd loop + BLIS_MR, // level 6: 1st loop + BLIS_KR // level 7: ukr loop + }; */ dim_t n_threads_in = 1; // Starting with the current element of the bszids array (pointed @@ -491,7 +491,7 @@ dim_t bli_rntm_calc_num_threads_in for ( ; *bszid_cur != BLIS_KR; bszid_cur++ ) { const bszid_t bszid = *bszid_cur; - dim_t cur_way = 1; + dim_t cur_way = 1; // We assume bszid is in {NC,KC,MC,NR,MR,KR} if it is not // BLIS_NO_PART. @@ -512,12 +512,12 @@ dim_t bli_rntm_calc_num_threads_in //application is available in global_rntm data structure. void bli_nthreads_optimum( - obj_t* a, - obj_t* b, - obj_t* c, - opid_t family, - rntm_t* rntm - ) + obj_t* a, + obj_t* b, + obj_t* c, + opid_t family, + rntm_t* rntm + ) { #ifndef BLIS_ENABLE_MULTITHREADING return; @@ -531,105 +531,112 @@ void bli_nthreads_optimum( if( family == BLIS_GEMM && bli_obj_is_double(c)) { - dim_t m = bli_obj_length(c); dim_t n = bli_obj_width(c); dim_t k = bli_obj_width_after_trans(a); - if( k >= 128) { - if(n <= 15) n_threads_ideal = 8; - else n_threads_ideal = 16; + if(n <= 15) + { + if(m < 128) n_threads_ideal = 8; + else if(m < 256) n_threads_ideal = 16; + else if(m < 512) n_threads_ideal = 32; + else n_threads_ideal = 64; + }else if (n <= 64) + { + if(m < 128) n_threads_ideal = 16; + else if(m < 256) n_threads_ideal = 32; + else n_threads_ideal = 64; + }else{ + if(m < 256) n_threads_ideal = 32; + else n_threads_ideal = 64; + } } else - { - if(m > 10000) - { - - /* if(n >= 96) n_threads_ideal = 16; */ - /* else n_threads_ideal = 8; */ - - // current logic is only limiting threads to - // less or equal to 64 - limits performance. - - // To deal with larger matrix sizes we need to use - // large number of threads to improve performance - - // Need to derive this upperTH - and - // if matrix -sizes are larger and user wants - // to use higher number of threads - that should be allowed. - - // if (n > UpperTH) n_threads_ideal = n_threads; - if (n > 200 ) n_threads_ideal = 64; - else if ( n > 120 ) n_threads_ideal = 32; - else if ( n > 40 ) n_threads_ideal = 16; - else if ( n > 10 ) n_threads_ideal = 8; - else /* if ( n <= 10) */ n_threads_ideal = 4; - } - else if( m > 1000) - { - if (n <= 10) n_threads_ideal = 4; - else if ( n <= 40 ) n_threads_ideal = 8; - else if ( n <= 120 ) n_threads_ideal = 16; - else if ( n <= 200 ) n_threads_ideal = 32; - else n_threads_ideal = 64; - - /* if(n < 15) n_threads_ideal = 4; */ - /* else n_threads_ideal = 8; */ - } - else if(m > 210) - { - if(n < 10) n_threads_ideal = 1; - else n_threads_ideal = 4; - } - else if(m > 150) - { - if(n < 15) n_threads_ideal = 1; - else n_threads_ideal = 4; - } - else if( ( m < 34) && (k < 68) && ( m < 34)) - { - n_threads_ideal = 1; - } - else - { - if(n < 20) n_threads_ideal = 1; - else n_threads_ideal = 4; - } + { + if(m > 10000) + { + // current logic is only limiting threads to + // less or equal to 64 - limits performance. + // To deal with larger matrix sizes we need to use + // large number of threads to improve performance + // Need to derive this upperTH - and + // if matrix -sizes are larger and user wants + // to use higher number of threads - that should be allowed. + + // if (n > UpperTH) n_threads_ideal = n_threads; + if (n > 200 ) n_threads_ideal = 64; + else if ( n > 120 ) n_threads_ideal = 32; + else if ( n > 40 ) n_threads_ideal = 16; + else if ( n > 10 ) n_threads_ideal = 8; + else n_threads_ideal = 4; + } + else if( m > 1000) + { + if (n <= 10) n_threads_ideal = 4; + else if ( n <= 512 ) n_threads_ideal = 8; + else if ( n <= 1024 ) n_threads_ideal = 16; + else if ( n <= 2048 ) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if(m > 210) + { + if(n < 10) n_threads_ideal = 4; + else if(n <= 512) n_threads_ideal = 8; + else if(n <= 1024) n_threads_ideal = 16; + else if(n <= 2048) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if(m > 150) + { + if(n < 10) n_threads_ideal = 2; + else if(n <= 512) n_threads_ideal = 8; + else if(n <= 1024) n_threads_ideal = 16; + else if(n <= 2048) n_threads_ideal = 32; + else n_threads_ideal = 64; + } + else if( ( m < 34) && (k < 68) && ( n < 34)) + { + n_threads_ideal = 1; + } + else + { //(m<150 && k<128) + if(n < 20) n_threads_ideal = 1; + if(n < 64) n_threads_ideal = 4; + else n_threads_ideal = 8; + } } - } else if( family == BLIS_GEMM && bli_obj_is_dcomplex(c)) - { - - dim_t m = bli_obj_length(c); - dim_t n = bli_obj_width(c); - dim_t k = bli_obj_width_after_trans(a); - - if((m<=128 || n<=128 || k<=128) && (m+n+k <= 400) ) - { - n_threads_ideal = 8; - } - else if((m<=256 || n<=256 || k<=256) && (m+n+k <= 800) ) - { - n_threads_ideal = 16; - } - } + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + dim_t k = bli_obj_width_after_trans(a); + + if((m<=128 || n<=128 || k<=128) && ((m+n+k) <= 400) ) + { + n_threads_ideal = 8; + } + else if((m<=256 || n<=256 || k<=256) && ((m+n+k) <= 800) ) + { + n_threads_ideal = 16; + } + } else if( family == BLIS_SYRK && bli_obj_is_double(c)) { - dim_t n = bli_obj_length(c); - dim_t k = bli_obj_width_after_trans(a); - - if( (( n <= 10) && ( k < 700)) || - (( n <= 20) && ( k <= 190)) || - (( n <= 40) && ( k <= 80)) || - (( n <= 50) && ( k <= 40)) || - (( n <= 60) && ( k <= 20)) - ) - n_threads_ideal = 1; - else - n_threads_ideal = n_threads; + dim_t n = bli_obj_length(c); + dim_t k = bli_obj_width_after_trans(a); + + if( (( n <= 10) && ( k < 700)) || + (( n <= 20) && ( k <= 190)) || + (( n <= 40) && ( k <= 80)) || + (( n <= 50) && ( k <= 40)) || + (( n <= 60) && ( k <= 20)) + ) + n_threads_ideal = 1; + else + n_threads_ideal = n_threads; } else if( family == BLIS_TRSM && bli_obj_is_double(c) ) { @@ -637,31 +644,34 @@ void bli_nthreads_optimum( dim_t n = bli_obj_width(c); #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - if ( (m <= 300) && (n <= 300) ) - n_threads_ideal = 8; - else if ( (m <= 400) && (n <= 400) ) - n_threads_ideal = 16; - else if ( (m <= 900) && (n <= 900) ) - n_threads_ideal = 32; + if ( (m <= 300) && (n <= 300) ) + n_threads_ideal = 8; + else if ( (m <= 400) && (n <= 400) ) + n_threads_ideal = 16; + else if ( (m <= 900) && (n <= 900) ) + n_threads_ideal = 32; #else - if ( (m <= 512) && (n <= 512) ) - n_threads_ideal = 4; + if ( (m <= 512) && (n <= 512) ) + n_threads_ideal = 4; #endif } else if( family == BLIS_TRSM && bli_obj_is_dcomplex(c)) - { - dim_t m = bli_obj_length(c); - dim_t n = bli_obj_width(c); + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); - if((m>=64) && (m<=256) && (n>=64) && (n<=256)) - n_threads_ideal = 8; - } + if((m>=64) && (m<=256) && (n>=64) && (n<=256)) + { + n_threads_ideal = 8; + } + } else if( family == BLIS_GEMMT && bli_obj_is_double(c) ) { dim_t n = bli_obj_length(c); dim_t k = bli_obj_width_after_trans(a); dim_t product = (n*k)>>4; /* product is derived based on n and k */ - // Limit the number thread for smaller sizes: + + //Limit the number thread for smaller sizes: if(product <= 346) { n_threads_ideal = 1; diff --git a/kernels/zen/1f/bli_axpyf_zen_int_5.c b/kernels/zen/1f/bli_axpyf_zen_int_5.c index 8b1f697cec..8fea5f6498 100644 --- a/kernels/zen/1f/bli_axpyf_zen_int_5.c +++ b/kernels/zen/1f/bli_axpyf_zen_int_5.c @@ -325,7 +325,6 @@ void bli_daxpyf_zen_int_5 const dim_t fuse_fac = 5; const dim_t n_elem_per_reg = 4; - const dim_t n_iter_unroll = 2; dim_t i; diff --git a/kernels/zen/bli_kernels_zen.h b/kernels/zen/bli_kernels_zen.h index 904e6cfbbf..f2edd993ce 100644 --- a/kernels/zen/bli_kernels_zen.h +++ b/kernels/zen/bli_kernels_zen.h @@ -329,6 +329,22 @@ err_t bli_trsm_small_mt cntx_t* cntx, cntl_t* cntl ); + +void bli_multi_sgemv_4x2 + ( + conj_t conjat, + conj_t conjx, + dim_t m, + dim_t b_n, + float* restrict alpha, + float* restrict a, inc_t inca, inc_t lda, + float* restrict x, inc_t incx, + float* restrict beta, + float* restrict y, inc_t incy, + cntx_t* restrict cntx, + dim_t n_threads + ); + #endif // threshold functions @@ -357,3 +373,4 @@ void bli_dnorm2fv_unb_var1 cntx_t* cntx ); #endif + From d536c33e7b940698db4b67682d8ce641cfe0e486 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 29 Apr 2022 17:13:29 +0530 Subject: [PATCH 52/63] Performance Improvement for ztrsm small sizes Details: - Handled Overflow and Underflow Vulnerabilites in ztrsm small right implementations. - Fixed failures observed in Scalapack testing. AMD-Internal: [CPUPL-2115] Change-Id: I22c1ba583e0ba14d1a4684a85fa1ca6e152e8439 --- kernels/zen/3/bli_trsm_small.c | 121 ++++++++++----------------------- 1 file changed, 35 insertions(+), 86 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index f8c0ea5911..d7192a062b 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -34922,38 +34922,21 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB { if(transa) { - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -35340,30 +35323,23 @@ BLIS_INLINE err_t bli_ztrsm_small_XAutB_XAlB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const*)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; @@ -36374,39 +36350,20 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB { if(transa) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+cs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); } else { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *) - (a11+rs_a*1 + 1)); + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); } - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); + _mm256_storeu_pd((double *)(d11_pack), ymm1); } - _mm256_storeu_pd((double *)(d11_pack), ymm1); for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { @@ -36793,30 +36750,22 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB } if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_pd((__m128d const *)(a11)); - ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm7 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0); - /*Taking denomerator multiplication of real & - * imaginary components*/ - ymm4 = _mm256_mul_pd(ymm1, ymm1); - /*Swapping real & imaginary component position for addition with - * respective components*/ - ymm6 = _mm256_permute4x64_pd(ymm4, 0xb1); - ymm4 = _mm256_add_pd(ymm4, ymm6); - /*Negating imaginary component of numerator*/ - ymm1 = _mm256_mul_pd(ymm1, ymm7); - /*Dividing numerator by denominator*/ - ymm1 = _mm256_div_pd(ymm1, ymm4); -#endif + if(transa) + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,cs_a, + d11_pack,n_remainder); + } + else + { + ztrsm_small_pack_diag_element(is_unitdiag,a11,rs_a, + d11_pack,n_remainder); + } } else { ymm1 = _mm256_broadcast_pd((__m128d const *)&ones); - } - _mm256_storeu_pd((double *)(d11_pack), ymm1); + _mm256_storeu_pd((double *)(d11_pack), ymm1); + } for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { From ec2884f33115c6d2a5528f229d6ea5d05b09c7e7 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Fri, 29 Apr 2022 10:51:55 +0530 Subject: [PATCH 53/63] Updated Zen3 architecture detection for Ryzen 5000 - Added support to detect Ryzen 5000 Desktop and APUs AMD-Internal: [CPUPL-2117] Change-Id: I312a7de1a84cf368b74ba20e58192803a9f7dace --- frame/base/bli_cpuid.c | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/frame/base/bli_cpuid.c b/frame/base/bli_cpuid.c index f4251a8c5c..d10ea1039a 100644 --- a/frame/base/bli_cpuid.c +++ b/frame/base/bli_cpuid.c @@ -286,8 +286,13 @@ bool bli_cpuid_is_zen3 // we check for all of them. const bool is_arch = - (( model <= 0x0f ) || - (0x30 <= model && model <= 0x3f )); + ( + ( model <= 0x0f ) || // EPYC and ThreadRipper + ( 0x20 <= model && model <= 0x2f ) || // Ryzen 5000 Desktop + ( 0x30 <= model && model <= 0x3f ) || // Trento + ( 0x40 <= model && model <= 0x4f ) || // RMB + ( 0x50 <= model && model <= 0x5f ) // Ryzen 5000 APU + ); if ( !is_arch ) return FALSE; From 2c4f8fd30ff906a2a289b881a0c1524a68b6d175 Mon Sep 17 00:00:00 2001 From: Nallani Bhaskar Date: Fri, 29 Apr 2022 23:29:20 +0530 Subject: [PATCH 54/63] Added AOCL Dynamic feature for dtrmm Description: 1. Tuned number of threads to achive better performance for dtrmm AMD-Internal: [ CPUPL-2100 ] Change-Id: Ib2e3df224ba76d86185721bef1837cd7855dd593 --- frame/3/trmm/CMakeLists.txt | 18 ++- frame/3/trmm/bli_trmm_front_amd.c | 206 ++++++++++++++++++++++++++++++ frame/base/bli_rntm.c | 93 ++++++++++++++ 3 files changed, 315 insertions(+), 2 deletions(-) create mode 100644 frame/3/trmm/bli_trmm_front_amd.c diff --git a/frame/3/trmm/CMakeLists.txt b/frame/3/trmm/CMakeLists.txt index 076d7d4a6b..a3845f3858 100644 --- a/frame/3/trmm/CMakeLists.txt +++ b/frame/3/trmm/CMakeLists.txt @@ -1,12 +1,26 @@ -##Copyright (C) 2020, Advanced Micro Devices, Inc. All rights reserved.## +##Copyright (C) 2020-22, Advanced Micro Devices, Inc. All rights reserved.## target_sources("${PROJECT_NAME}" PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_ll_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_lu_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_rl_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_ru_ker_var2.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_xx_ker_var2.c ) +# Select AMD specific sources for AMD configurations. +if(${TARGET_ARCH} STREQUAL zen OR +${TARGET_ARCH} STREQUAL zen2 OR +${TARGET_ARCH} STREQUAL zen3 OR +${TARGET_ARCH} STREQUAL amdzen) + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front_amd.c + ) +else() + target_sources("${PROJECT_NAME}" + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trmm_front.c + ) +endif() diff --git a/frame/3/trmm/bli_trmm_front_amd.c b/frame/3/trmm/bli_trmm_front_amd.c new file mode 100644 index 0000000000..2301b323a7 --- /dev/null +++ b/frame/3/trmm/bli_trmm_front_amd.c @@ -0,0 +1,206 @@ +/* + + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + + Copyright (C) 2014, The University of Texas at Austin + Copyright (C) 2022, Advanced Micro Devices, Inc. + + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +*/ + +#include "blis.h" + +void bli_trmm_front + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + rntm_t* rntm, + cntl_t* cntl + ) +{ + bli_init_once(); + + obj_t a_local; + obj_t b_local; + obj_t c_local; + + // Check parameters. + if ( bli_error_checking_is_enabled() ) + bli_trmm_check( side, alpha, a, b, &BLIS_ZERO, b, cntx ); + + // If alpha is zero, scale by beta and return. + if ( bli_obj_equals( alpha, &BLIS_ZERO ) ) + { + bli_scalm( alpha, b ); + return; + } + + // Alias A and B so we can tweak the objects if necessary. + bli_obj_alias_to( a, &a_local ); + bli_obj_alias_to( b, &b_local ); + bli_obj_alias_to( b, &c_local ); + + // We do not explicitly implement the cases where A is transposed. + // However, we can still handle them. Specifically, if A is marked as + // needing a transposition, we simply induce a transposition. This + // allows us to only explicitly implement the no-transpose cases. Once + // the transposition is induced, the correct algorithm will be called, + // since, for example, an algorithm over a transposed lower triangular + // matrix A moves in the same direction (forwards) as a non-transposed + // upper triangular matrix. And with the transposition induced, the + // matrix now appears to be upper triangular, so the upper triangular + // algorithm will grab the correct partitions, as if it were upper + // triangular (with no transpose) all along. + if ( bli_obj_has_trans( &a_local ) ) + { + bli_obj_induce_trans( &a_local ); + bli_obj_set_onlytrans( BLIS_NO_TRANSPOSE, &a_local ); + } + +#ifdef BLIS_DISABLE_TRMM_RIGHT + // NOTE: This case casts right-side trmm in terms of left side. This is + // necessary when the current subconfiguration uses a gemm microkernel + // that assumes that the packing kernel will have already duplicated + // (broadcast) element of B in the packed copy of B. Supporting + // duplication within the logic that packs micropanels from triangular + // matrices would be ugly, and so we simply don't support it. As a + // consequence, those subconfigurations need a way to force the triangular + // matrix to be on the left (and thus the general matrix to the on the + // right). So our solution is that in those cases, the subconfigurations + // simply #define BLIS_DISABLE_TRMM_RIGHT. + + // NOTE: This case casts right-side trmm in terms of left side. This can + // lead to the microkernel being executed on an output matrix with the + // microkernel's general stride IO case (unless the microkernel supports + // both both row and column IO cases as well). + + // NOTE: Casting right-side trmm in terms of left side reduces the number + // of macrokernels exercised to two (trmm_ll and trmm_lu). + + // If A is being multiplied from the right, transpose all operands + // so that we can perform the computation as if A were being multiplied + // from the left. + if ( bli_is_right( side ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + +#else + // NOTE: This case computes right-side trmm natively with trmm_rl and + // trmm_ru macrokernels. This code path always gives us the opportunity + // to transpose the entire operation so that the effective storage format + // of the output matrix matches the microkernel's output preference. + // Thus, from a performance perspective, this case is preferred. + + // An optimization: If C is stored by rows and the micro-kernel prefers + // contiguous columns, or if C is stored by columns and the micro-kernel + // prefers contiguous rows, transpose the entire operation to allow the + // micro-kernel to access elements of C in its preferred manner. + // NOTE: We disable the optimization for 1x1 matrices since the concept + // of row- vs. column storage breaks down. + //if ( !bli_obj_is_1x1( &c_local ) ) // NOTE: This conditional should NOT + // be enabled. See issue #342 comments. + if ( bli_cntx_l3_vir_ukr_dislikes_storage_of( &c_local, BLIS_GEMM_UKR, cntx ) ) + { + bli_toggle_side( &side ); + bli_obj_induce_trans( &a_local ); + bli_obj_induce_trans( &b_local ); + bli_obj_induce_trans( &c_local ); + } + + // If A is being multiplied from the right, swap A and B so that + // the matrix will actually be on the right. + if ( bli_is_right( side ) ) + { + bli_obj_swap( &a_local, &b_local ); + } + +#endif + + // Set each alias as the root object. + // NOTE: We MUST wait until we are done potentially swapping the objects + // before setting the root fields! + bli_obj_set_as_root( &a_local ); + bli_obj_set_as_root( &b_local ); + bli_obj_set_as_root( &c_local ); + +#ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads and update in rntm + if(bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRMM, rntm ); + } +#endif + + // Parse and interpret the contents of the rntm_t object to properly + // set the ways of parallelism for each loop, and then make any + // additional modifications necessary for the current operation. + bli_rntm_set_ways_for_op + ( + BLIS_TRMM, + side, + bli_obj_length( &c_local ), + bli_obj_width( &c_local ), + bli_obj_width( &a_local ), + rntm + ); + + // A sort of hack for communicating the desired pach schemas for A and B + // to bli_gemm_cntl_create() (via bli_l3_thread_decorator() and + // bli_l3_cntl_create_if()). This allows us to access the schemas from + // the control tree, which hopefully reduces some confusion, particularly + // in bli_packm_init(). + pack_t schema_a = bli_cntx_schema_a_block( cntx ); + pack_t schema_b = bli_cntx_schema_b_panel( cntx ); + + bli_obj_set_pack_schema( schema_a, &a_local ); + bli_obj_set_pack_schema( schema_b, &b_local ); + + // Invoke the internal back-end. + bli_l3_thread_decorator + ( + bli_gemm_int, + BLIS_TRMM, // operation family id + alpha, + &a_local, + &b_local, + &BLIS_ZERO, + &c_local, + cntx, + rntm, + cntl + ); +} + diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index f8d48c4a2e..1d6c41528c 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -682,6 +682,99 @@ void bli_nthreads_optimum( n_threads_ideal = n_threads; } } + else if( family == BLIS_TRMM && bli_obj_is_double(c)) + { + dim_t m = bli_obj_length(c); + dim_t n = bli_obj_width(c); + + if(( n <= 32) && (m <= 32)) + { + n_threads_ideal=1; + /*If Side is Left*/ + }else + { + //Left Side + if(bli_obj_is_triangular(a)) + { + if((m < 300)) + { + if (n < 1000) + { + n_threads_ideal=8; + }else if (n < 2000) + { + n_threads_ideal=16; + }else if (n < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else if(m < 600) + { + if (n < 2000) + { + n_threads_ideal=16; + }else if (n < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else + { + if(n < 1000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + } + }else//Right Side + { + if((n < 300)) + { + if (m < 1000) + { + n_threads_ideal=8; + }else if (m < 2000) + { + n_threads_ideal=16; + }else if (m < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else if(n < 600) + { + if (m < 2000) + { + n_threads_ideal=16; + }else if (m < 3000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + }else + { + if(m < 1000) + { + n_threads_ideal=32; + }else + { + n_threads_ideal=64; + } + } + } + } + } dim_t n_threads_opt = bli_min(n_threads, n_threads_ideal); From 5f437903a9cd28b9726654d9fb439dfa06c008a5 Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 5 May 2022 12:05:40 +0530 Subject: [PATCH 55/63] Fixed crash issue in TRSM on non-avx platform. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Ensured that FMA, AVX2 based kernels are called only on platforms supporting these instructions, otherwise standard ‘C’ kernels will be called. - Code cleanup for optimization and consistency AMD-Internal: [CPUPL-2126] Change-Id: I203270892b2fad2ccc9301fb55e2bae75508e050 --- frame/compat/bla_trsm_amd.c | 228 +++++++++++++++++++----------------- 1 file changed, 119 insertions(+), 109 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 3b3850928a..f479b5eac0 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -594,10 +594,11 @@ void strsm_ bli_obj_set_struc( struca, &ao ); +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx_supported() == TRUE) { -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM + if (bli_cpuid_is_avx_supported() == TRUE) + { /* bli_strsm_small is performing better existing native * implementations for [m,n]<=1000 for single thread. * In case of multithread when [m,n]<=128 sinlge thread implemenation @@ -624,8 +625,9 @@ void strsm_ return; } } -#endif } +#endif + bli_trsmnat ( blis_side, @@ -854,76 +856,72 @@ void dtrsm_ bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_struc( struca, &ao ); - + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM // This function is invoked on all architectures including ‘generic’. // Non-AVX platforms will use the kernels derived from the context. - if (bli_cpuid_is_avx_supported() == TRUE) { - -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_dtrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) - { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) - { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; - } - } - - //bli_trsm_small_mt is performing better than native multithread - //for certain sizes of m & n. -#ifdef BLIS_ENABLE_OPENMP - rntm_t rntm; - bli_rntm_init_from_global( &rntm ); - - // Query the total number of threads from the rntm_t object. - dim_t n_threads = bli_rntm_num_threads( &rntm ); - if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || - ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || - ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || - ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || - ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || - ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) + if (bli_cpuid_is_avx_supported() == TRUE) { - err_t status; - status = bli_trsm_small_mt - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); + /* bli_dtrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if ((nt == 0 && m0 <= 1000 && n0 <= 1000) || + (nt && (m0 + n0) < 320)) + { + err_t status; + status = bli_trsm_small( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } + } - if ( status == BLIS_SUCCESS ) + // bli_trsm_small_mt is performing better than native multithread + // for certain sizes of m & n. +#ifdef BLIS_ENABLE_OPENMP + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads( &rntm ); + if ( ( (n_threads > 1) && (m0 <= 1500) && (n0 <= 1500) ) || + ( (n_threads == 32) && (m0 <= 2300) && (n0 <= 2300) ) || + ( (n_threads == 16) && (m0 <= 3800) && (n0 <= 3800) ) || + ( (n_threads == 8) && (m0 <= 2800) && (n0 <= 2800) ) || + ( (n_threads == 4) && (m0 <= 2000) && (n0 <= 2000) ) || + ( (n_threads == 2) && (m0 <= 2000) && (n0 <= 2000) ) ) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + err_t status; + status = bli_trsm_small_mt( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL); + + if ( status == BLIS_SUCCESS ) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); /* Finalize BLIS. */ bli_finalize_auto(); return; + } } - } #endif// BLIS_ENABLE_OPENMP + } // bli_cpuid_is_avx_supported #endif// END of BLIS_ENABLE_SMALL_MATRIX_TRSM - } bli_trsmnat ( @@ -1217,33 +1215,38 @@ void ztrsm_ bli_obj_set_struc( struca, &ao ); #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - - if(((nt==0) && (m0<=500) && (n0<=500)) || - (nt && ((m0+n0)<128))) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + + if(((nt==0) && (m0<=500) && (n0<=500)) || + (nt && ((m0+n0)<128))) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } } - } + } // bli_cpuid_is_avx_supported} #endif bli_trsmnat @@ -1535,34 +1538,41 @@ void ctrsm_ bli_obj_set_conjtrans( blis_transa, &ao ); bli_obj_set_struc( struca, &ao ); + #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM - /* bli_ztrsm_small is performing better existing native - * implementations for [m,n]<=1000 for single thread. - * In case of multithread when [m,n]<=128 sinlge thread implemenation - * is doing better than native multithread */ - bool nt = bli_thread_get_is_parallel(); - if((nt==0 && m0<=1000 && n0<=1000) || - (nt && (m0+n0)<320) ) + // This function is invoked on all architectures including ‘generic’. + // Non-AVX platforms will use the kernels derived from the context. + if (bli_cpuid_is_avx_supported() == TRUE) { - err_t status; - status = bli_trsm_small - ( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL - ); - if (status == BLIS_SUCCESS) + /* bli_ztrsm_small is performing better existing native + * implementations for [m,n]<=1000 for single thread. + * In case of multithread when [m,n]<=128 sinlge thread implemenation + * is doing better than native multithread */ + bool nt = bli_thread_get_is_parallel(); + if((nt==0 && m0<=1000 && n0<=1000) || + (nt && (m0+n0)<320) ) { - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); - /* Finalize BLIS. */ - bli_finalize_auto(); - return; + err_t status; + status = bli_trsm_small + ( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL + ); + if (status == BLIS_SUCCESS) + { + AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); + /* Finalize BLIS. */ + bli_finalize_auto(); + return; + } } - } + } // bli_cpuid_is_avx_supported #endif + bli_trsmnat ( blis_side, From 6acbc83b2bfd8fcf8c624a3b852823b516d819d0 Mon Sep 17 00:00:00 2001 From: Harsh Dave Date: Fri, 6 May 2022 06:34:15 -0500 Subject: [PATCH 56/63] Fixed scalapack xcsep failer due to cdotxv kernel. -Failure was observed in zen configuration as gcc flag safe-math-optimization was being used for reference kernel compilation. - Optmized kernels were being compiled without this gcc flag resulted in computation difference resulting in test case failure. AMD-Internal: [CPUPL-2121] Change-Id: I5d86e589cdea633220aecadbcab84d9b88b31f57 --- config/generic/make_defs.mk | 4 ++-- config/zen/make_defs.mk | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/config/generic/make_defs.mk b/config/generic/make_defs.mk index ee77b6cf0e..4ce2fac758 100644 --- a/config/generic/make_defs.mk +++ b/config/generic/make_defs.mk @@ -79,10 +79,10 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) else ifeq ($(CC_VENDOR),clang) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations -ffp-contract=fast +CRVECFLAGS := $(CKVECFLAGS) else CRVECFLAGS := $(CKVECFLAGS) endif diff --git a/config/zen/make_defs.mk b/config/zen/make_defs.mk index 08d8628bec..b4153fcbfb 100644 --- a/config/zen/make_defs.mk +++ b/config/zen/make_defs.mk @@ -68,7 +68,7 @@ endif # Flags specific to reference kernels. CROPTFLAGS := $(CKOPTFLAGS) ifeq ($(CC_VENDOR),gcc) -CRVECFLAGS := $(CKVECFLAGS) -funsafe-math-optimizations +CRVECFLAGS := $(CKVECFLAGS) else CRVECFLAGS := $(CKVECFLAGS) endif From 23a3e88657a443a38912601f74a91937e64e2003 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Tue, 10 May 2022 14:46:47 +0530 Subject: [PATCH 57/63] Default sgemv kernel to be used in single-threaded scenarios. - sgemv calls a multi-threading friendly kernel whenever it is compiled with open mp and multi-threading enabled. However it was observed that this kernel is not suited for scenarios where sgemv is invoked in a single-threaded context (eg: sgemv from ST sgemm fringe kernels and with matrix blocking). Falling back to the default single-threaded sgemv kernel resulted in better performance for this scenario. AMD-Internal: [CPUPL-2136] Change-Id: Ic023db4d20b2503ea45e56a839aa35de0337d5a6 --- frame/2/gemv/bli_gemv_unf_var1_amd.c | 97 +++++++++++++++------------- 1 file changed, 51 insertions(+), 46 deletions(-) diff --git a/frame/2/gemv/bli_gemv_unf_var1_amd.c b/frame/2/gemv/bli_gemv_unf_var1_amd.c index fd399c6f84..447f8dbc43 100644 --- a/frame/2/gemv/bli_gemv_unf_var1_amd.c +++ b/frame/2/gemv/bli_gemv_unf_var1_amd.c @@ -495,72 +495,77 @@ void bli_sgemv_unf_var1 // If both multithreading and OpenMP are enabled, GEMV will multithread #if defined(BLIS_ENABLE_MULTITHREADING) && defined(BLIS_ENABLE_OPENMP) - dim_t nt, nt_max; - - rntm_t rnmt_obj; + bool is_omp_mt_enabled = TRUE; +#else + bool is_omp_mt_enabled = FALSE; +#endif - b_fuse = 4; + dim_t nt_max; + rntm_t rnmt_obj; // Initialize a local runtime with global settings. bli_rntm_init_from_global( &rnmt_obj ); // Query the total number of threads from the rntm_t object. nt_max = bli_rntm_num_threads( &rnmt_obj ); - - //Setting the thread count to the maximum number of threads provided - nt = nt_max; - - // Enable smart threading when AOCL dynamic is enabled - #ifdef AOCL_DYNAMIC - bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); - #endif - - // Pass the input paramaters along with the number of threads to be used - bli_multi_sgemv_4x2 - ( - conja, - conjx, - n_elem, - n_iter, - alpha, - a, cs_at, rs_at, - x, incx, - beta, - y, incy, - cntx, - nt - ); - -#else - b_fuse = 8; - - for ( i = 0; i < n_iter; i += f ) + if ( ( nt_max > 1 ) & ( is_omp_mt_enabled == TRUE ) ) { - float* x1; - f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + b_fuse = 4; + + //Setting the thread count to the maximum number of threads provided + dim_t nt = nt_max; - A1 = a + (i )*rs_at + (0 )*cs_at; - x1 = x + (0 )*incy; - y1 = y + (i )*incy; + // Enable smart threading when AOCL dynamic is enabled + #ifdef AOCL_DYNAMIC + bli_sgemv_var1_smart_threading(n_elem, n_iter, b_fuse, &nt, nt_max); + #endif - /* y1 = beta * y1 + alpha * A1 * x; */ - bli_sdotxf_zen_int_8 + // Pass the input paramaters along with the number of threads to be used + bli_multi_sgemv_4x2 ( conja, conjx, n_elem, - f, + n_iter, alpha, - A1, cs_at, rs_at, - x1, incx, + a, cs_at, rs_at, + x, incx, beta, - y1, incy, - cntx + y, incy, + cntx, + nt ); + } + else + { + b_fuse = 8; + for ( i = 0; i < n_iter; i += f ) + { + float* x1; + f = bli_determine_blocksize_dim_f( i, n_iter, b_fuse ); + + A1 = a + (i )*rs_at + (0 )*cs_at; + x1 = x + (0 )*incy; + y1 = y + (i )*incy; + + /* y1 = beta * y1 + alpha * A1 * x; */ + bli_sdotxf_zen_int_8 + ( + conja, + conjx, + n_elem, + f, + alpha, + A1, cs_at, rs_at, + x1, incx, + beta, + y1, incy, + cntx + ); + } } -#endif } INSERT_GENTFUNC_BASIC0_CZ( gemv_unf_var1 ) From c35e433708d80da75397a4cc5f0bd75c453d8081 Mon Sep 17 00:00:00 2001 From: mkadavil Date: Thu, 12 May 2022 18:02:01 +0530 Subject: [PATCH 58/63] Bug fixes for open mp based multi-threaded GEMM/GEMMT SUP path. - auto_factor to be disabled if BLIS_IC_NT/BLIS_JC_NT is set irrespective of whether num_threads (BLIS_NUM_THREADS) is modified at runtime. Currently the auto_factor is enabled if num_threads > 0 and not reverted if ic/jc/pc/jr/ir ways are set in bli_rntm_set_ways_from_rntm. This results in gemm/gemmt SUP path applying 2x2 factorization of num_threads, and thereby modifying the preset factorization. This issue is not observed in native path since factorization happens without checking auto_factor value. - Setting omp threads to n_threads using omp_set_num_threads after the global_rntm n_threads update in bli_thread_set_num_threads. This ensures that in bli_rntm_init_from_global, omp_get_max_threads returns the same value as set previously. AMD-Internal: [CPUPL-2137] Change-Id: I6c5de0462c5837cfb64793c3e6d49ec3ac2b6426 --- frame/base/bli_rntm.c | 10 ++++++++++ frame/thread/bli_thread.c | 12 ++++++++++++ 2 files changed, 22 insertions(+) diff --git a/frame/base/bli_rntm.c b/frame/base/bli_rntm.c index 1d6c41528c..fbf5654b7a 100644 --- a/frame/base/bli_rntm.c +++ b/frame/base/bli_rntm.c @@ -219,6 +219,11 @@ void bli_rntm_set_ways_from_rntm if ( ic < 1 ) ic = 1; if ( jr < 1 ) jr = 1; if ( ir < 1 ) ir = 1; + + // auto factorization is to be disabled if BLIS_IC_NT/BLIS_JC_NT env + // variables are set irrespective of whether num_threads is modified + // or not. This ensures that preset factorization is prioritized. + auto_factor = FALSE; } // Now we use the values of nt_set and ways_set to determine how to @@ -340,6 +345,11 @@ void bli_rntm_set_ways_from_rntm_sup if ( ic < 1 ) ic = 1; if ( jr < 1 ) jr = 1; if ( ir < 1 ) ir = 1; + + // auto factorization is to be disabled if BLIS_IC_NT/BLIS_JC_NT env + // variables are set irrespective of whether num_threads is modified + // or not. This ensures that preset factorization is prioritized. + auto_factor = FALSE; } // Now we use the values of nt_set and ways_set to determine how to diff --git a/frame/thread/bli_thread.c b/frame/thread/bli_thread.c index f570bcc2d8..097d136e7e 100644 --- a/frame/thread/bli_thread.c +++ b/frame/thread/bli_thread.c @@ -1604,11 +1604,23 @@ void bli_thread_set_num_threads( dim_t n_threads ) // We must ensure that global_rntm has been initialized. bli_init_once(); + if ( n_threads <= 0 ) + { + n_threads = 1; + } + // Acquire the mutex protecting global_rntm. bli_pthread_mutex_lock( &global_rntm_mutex ); bli_rntm_set_num_threads_only( n_threads, &global_rntm ); +#ifdef BLIS_ENABLE_OPENMP + // In the function bli_rntm_init_from_global() we extract n_threads + // using the API omp_get_max_threads(). Following step ensures that + // omp_get_max_threads returns the same value as set here. + omp_set_num_threads( n_threads ); +#endif + // Release the mutex protecting global_rntm. bli_pthread_mutex_unlock( &global_rntm_mutex ); } From 8bdc484e5859e5502f99beaf2e3ea085ab633b3b Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Mon, 16 May 2022 15:57:12 +0530 Subject: [PATCH 59/63] Disable AOCL_VERBOSE feature - AOCL_VERBOSE implementation is causing breakage in libFLAME. Currently DTL code is duplicated in BLIS and libFLAME, Which results in duplicate symbol errors when DTL is enabled in both the libraries. - It will be addressed by making DTL as separate library. - The input logs can still be enabled by setting AOCL_DTL_LOG_ENABLE = 1 in aocldtlcf.h and recompiling the BLIS library. AMD-Internal: [CPUPL-2101] Change-Id: I8e69b68d53940e306a1d16ffbb65019def7e655a --- aocl_dtl/aocldtl.c | 4 ++-- aocl_dtl/aocldtlcf.h | 18 +++--------------- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/aocl_dtl/aocldtl.c b/aocl_dtl/aocldtl.c index f3c1658ff8..6e7ee35102 100644 --- a/aocl_dtl/aocldtl.c +++ b/aocl_dtl/aocldtl.c @@ -59,7 +59,7 @@ AOCL_FLIST_Node *gpLogFileList = NULL; /* Global flag to check if logging is enabled or not */ -Bool gbIsLoggingEnabled = FALSE; +Bool gbIsLoggingEnabled = TRUE; #endif #if AOCL_DTL_AUTO_TRACE_ENABLE @@ -130,7 +130,7 @@ void DTL_Initialize( #if (AOCL_DTL_LOG_ENABLE || AOCL_DTL_DUMP_ENABLE) /* Check if DTL logging is requested via envoronment variable */ - gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", FALSE ); + gbIsLoggingEnabled = bli_env_get_var( "AOCL_VERBOSE", TRUE ); #endif #if AOCL_DTL_AUTO_TRACE_ENABLE diff --git a/aocl_dtl/aocldtlcf.h b/aocl_dtl/aocldtlcf.h index 9420e7d364..1f44f54405 100644 --- a/aocl_dtl/aocldtlcf.h +++ b/aocl_dtl/aocldtlcf.h @@ -20,21 +20,9 @@ enable this macro by making it to 1 else 0 */ #define AOCL_DTL_DUMP_ENABLE 0 -/* - * Logging of inputs can be enabled by two methods: - * - * 1. Using environment variable AOCL_VERBOSE. - * 2. APIs AOCL_DTL_Enable_Logs(), AOCL_DTL_Disable_Logs() - * - * The API takes precedence over environment variable. - * - * The global flag is maintain in the code to track the final - * state of the logging feature. - * - * Setting AOCL_DTL_LOG_ENABLE = 0 will disable this feature - * completely and it is not recommended. - */ -#define AOCL_DTL_LOG_ENABLE 1 +/* Macro for dumping the log If the user wants to enable input logs he has to + enable this macro by making it to 1 else 0 */ +#define AOCL_DTL_LOG_ENABLE 0 /* Select the trace level till which you want to log the data */ /* By default it will log for all levels */ From 1c9d55fd5937b534dbc837b12258c9e454126128 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Wed, 18 May 2022 16:53:24 +0530 Subject: [PATCH 60/63] Disabled zgemm SUP path - Need to identify new Thresholds for zgemm SUP path to avoid performance regression. AMD-Internal: [CPUPL-2148] Change-Id: I0baa2b415dc5e296780566ba7450249445b93d43 --- frame/compat/bla_gemm_amd.c | 9 --------- 1 file changed, 9 deletions(-) diff --git a/frame/compat/bla_gemm_amd.c b/frame/compat/bla_gemm_amd.c index 681869c9b8..99d7371778 100644 --- a/frame/compat/bla_gemm_amd.c +++ b/frame/compat/bla_gemm_amd.c @@ -753,15 +753,6 @@ void zgemm_ } } #endif - - err_t status = bli_gemmsup(&alphao, &ao, &bo, &betao, &co, NULL, NULL); - if(status==BLIS_SUCCESS) - { - AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); - AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_TRACE_1) - return; - } - // fall back on native path when zgemm is not handled in sup path. bli_gemmnat(&alphao, &ao, &bo, &betao, &co, NULL, NULL); AOCL_DTL_LOG_GEMM_STATS(AOCL_DTL_LEVEL_TRACE_1, *m, *n, *k); From 45b70caadac533c10ae8fc1b9305b86fb216dd7c Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Thu, 26 May 2022 14:36:30 +0530 Subject: [PATCH 61/63] Updated BLIS version to 3.2.0 AMD-Internal: [CPUPL-2161] Change-Id: Ie4b9920b84e4643d13eaf0bb662b8b163125f7b3 --- so_version | 2 +- version | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/so_version b/so_version index 77605e74c7..8efd5969fe 100644 --- a/so_version +++ b/so_version @@ -1,2 +1,2 @@ 3 -1.2 +2.0 diff --git a/version b/version index ef538c2810..944880fa15 100644 --- a/version +++ b/version @@ -1 +1 @@ -3.1.2 +3.2.0 From 0fd0a3dd5a1f196356f06d1d293a741f72208f8c Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Mon, 13 Jun 2022 09:52:45 +0530 Subject: [PATCH 62/63] BugFix of AOCL_DYNAMIC in TRSM multithreaded small. - Added initialization of rntm object before aocl_dynamic. - Bugfixes in dtrsm right-side kernels, avoided accessing extra memory while using store for corner cases. AMD-Internal: [CPUPL-2193] [CPUPL-2194] Change-Id: I1c9d10edda93621626957d4de2f53d249ad531ba --- kernels/zen/3/bli_trsm_small.c | 271 +++++++++------------------------ 1 file changed, 76 insertions(+), 195 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index d7192a062b..bb8a2e9cc5 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -3908,7 +3908,6 @@ err_t bli_trsm_small_mt cntl_t* cntl ) { - rntm_t rntm; gint_t m = bli_obj_length( b ); // number of rows of matrix b gint_t n = bli_obj_width( b ); // number of columns of Matrix b dim_t d_mr = 8,d_nr = 6; @@ -3928,6 +3927,9 @@ err_t bli_trsm_small_mt } } + rntm_t rntm; + bli_rntm_init_from_global( &rntm ); + #ifdef AOCL_DYNAMIC // If dynamic-threading is enabled, calculate optimum number // of threads. @@ -3938,8 +3940,6 @@ err_t bli_trsm_small_mt } #endif - bli_rntm_init_from_global( &rntm ); - // Query the total number of threads from the rntm_t object. dim_t n_threads = bli_rntm_num_threads( &rntm ); @@ -6727,25 +6727,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b*4 + 2), _mm256_extractf128_pd(ymm11,1)); + _mm_storel_pd((double *)(b11 + cs_b*5 + 2), _mm256_extractf128_pd(ymm13,1)); m_remainder -= 3; i += 3; @@ -6857,25 +6851,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -= 2; i += 2; @@ -6987,25 +6968,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -= 1; i += 1; @@ -7397,23 +7365,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); m_remainder -= 3; i += 3; @@ -7494,21 +7454,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); m_remainder -= 2; i += 2; @@ -7588,15 +7537,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); @@ -9165,25 +9105,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); + _mm_storel_pd((double *)(b11 + cs_b*4 + 2), _mm256_extractf128_pd(ymm11,1)); + _mm_storel_pd((double *)(b11 + cs_b*5 + 2), _mm256_extractf128_pd(ymm13,1)); m_remainder -=3; } @@ -9286,25 +9220,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storeu_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -=2; } @@ -9407,25 +9328,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); + _mm_storel_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storel_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storel_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storel_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storel_pd((double *)(b11 + cs_b*4), _mm256_extractf128_pd(ymm11,0)); + _mm_storel_pd((double *)(b11 + cs_b*5), _mm256_extractf128_pd(ymm13,0)); m_remainder -=1; } @@ -9806,23 +9714,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3,1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5,1)); + _mm_storel_pd((double *)(b11 + cs_b*2 + 2), _mm256_extractf128_pd(ymm7,1)); + _mm_storel_pd((double *)(b11 + cs_b*3 + 2), _mm256_extractf128_pd(ymm9,1)); m_remainder -=3; } @@ -9898,21 +9798,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storeu_pd((double *)b11, _mm256_extractf128_pd(ymm3,0)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_extractf128_pd(ymm5,0)); + _mm_storeu_pd((double *)(b11 + cs_b*2), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(b11 + cs_b*3), _mm256_extractf128_pd(ymm9,0)); m_remainder -=2; } @@ -9985,15 +9874,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); From 1a63b9f0f9b82d76ae126fde0f86dd392b18fefc Mon Sep 17 00:00:00 2001 From: Dipal M Zambare Date: Tue, 14 Jun 2022 08:30:51 +0530 Subject: [PATCH 63/63] Fixed high impact static analysis issues Initialized ymm and xmm registers to zero to address un-inilizaed variable errors reported in static analsys. AMD-Internal: [CPUPL-2078] Change-Id: Icfcc008a0f244278efd8145d7feef764ed5fcc04 --- kernels/zen/3/bli_dgemm_ref_k1.c | 6 +++++- kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c | 8 +++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/kernels/zen/3/bli_dgemm_ref_k1.c b/kernels/zen/3/bli_dgemm_ref_k1.c index 659975cdb7..03a2b789bb 100644 --- a/kernels/zen/3/bli_dgemm_ref_k1.c +++ b/kernels/zen/3/bli_dgemm_ref_k1.c @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2021, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -394,6 +394,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); @@ -690,6 +691,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); @@ -897,6 +899,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); @@ -1052,6 +1055,7 @@ void bli_dgemm_ref_k1_nn if(m_rem == 1) { + ymm0 = _mm256_setzero_pd(); ymm3 = _mm256_setzero_pd(); ymm15 = _mm256_setzero_pd(); diff --git a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c index a21c9b5ed1..77f0348561 100644 --- a/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c +++ b/kernels/zen/3/sup/bli_gemmsup_rv_zen_asm_c3x8n.c @@ -6,7 +6,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2020-2021, Advanced Micro Devices, Inc. + Copyright (C) 2020-2022, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -138,6 +138,8 @@ void bli_cgemmsup_rv_zen_asm_3x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); @@ -572,6 +574,8 @@ void bli_cgemmsup_rv_zen_asm_2x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps(); @@ -919,6 +923,8 @@ void bli_cgemmsup_rv_zen_asm_1x8n for (n_iter = 0; n_iter < n0 / 8; n_iter++) { // clear scratch registers. + xmm0 = _mm_setzero_ps(); + xmm3 = _mm_setzero_ps(); ymm4 = _mm256_setzero_ps(); ymm5 = _mm256_setzero_ps(); ymm6 = _mm256_setzero_ps();